Example #1
def _set_network_attributes_from_metadata(revived_obj):
    """Sets attributes recorded in the metadata."""
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        metadata = revived_obj._serialized_attributes['metadata']
        if metadata.get('dtype') is not None:
            revived_obj._set_dtype_policy(metadata['dtype'])
        revived_obj.trainable = metadata['trainable']
Example #2
    def _init_from_metadata(cls, metadata):
        """Create revived model from metadata stored in the SavedModel proto."""
        revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)

        with trackable.no_automatic_dependency_tracking_scope(revived_obj):
            revived_obj._training_config = metadata.get('training_config')  # pylint:disable=protected-access

        return revived_obj
Example #3
def _replace_child_layer_functions(layer, serialization_cache):
    """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  This function also resets all losses stored in the layer. These are stored in
  the returned dictionary. Use `_restore_child_layer_functions` to restore
  the original attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions and losses:
      { Child layer 1: {
          'losses': Original losses,
          'call': Original call function
          'activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
  """
    # pylint: disable=protected-access
    original_fns = {}
    for child_layer in _list_all_layers(layer):
        if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
            layer_fns = (
                child_layer._trackable_saved_model_saver.
                _get_serialized_attributes(serialization_cache).functions)
        else:
            layer_fns = (serialization_cache[constants.KERAS_CACHE_KEY]
                         [child_layer].functions)
        if not layer_fns:
            # This indicates either:
            #   - circular dependency, which means the current layer's functions
            #     should be wrapped first.
            #   - Child layer's inputs are not defined, so its functions have not been
            #     wrapped. In this case, no replacement is necessary so move on to the
            #     next child.
            continue
        original_fns[child_layer] = {
            'call': child_layer.call,
            'activity_regularizer': child_layer.activity_regularizer
        }
        with trackable.no_automatic_dependency_tracking_scope(child_layer):
            try:
                child_layer.activity_regularizer = layer_fns.get(
                    'activity_regularizer_fn')
            except AttributeError:
                # Some layers have an unsettable activity regularizer.
                pass
            child_layer.call = utils.use_wrapped_call(
                child_layer,
                layer_fns['call_and_return_conditional_losses'],
                default_training_value=False)
    return original_fns
Example #4
def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced with `_replace_child_layer_functions`."""
  for child_layer, fns in original_fns.items():
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      for fn_name, fn in fns.items():
        try:
          setattr(child_layer, fn_name, fn)  # pylint: disable=protected-access
        except AttributeError:
          pass  # In the case of _activity_regularizer, setting the attribute
          # may be disallowed (some layers have an unsettable activity regularizer).
Example #5
def _restore_child_layer_functions(original_fns):
    """Restores attributes replaced with `_replace_child_layer_functions`."""
    for child_layer, fns in original_fns.items():
        with trackable.no_automatic_dependency_tracking_scope(child_layer):
            child_layer.call = fns['call']
            try:
                child_layer.activity_regularizer = fns['activity_regularizer']
            except AttributeError:
                pass
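The replace/restore pair shown in Examples #3 to #5 is meant to bracket the step that traces the parent layer's serialized functions. A minimal sketch of that pattern, reusing the names from the examples above (the tracing step in the middle is a hypothetical placeholder):

# Sketch only: wrap the children's functions, trace the parent, then restore.
original_fns = _replace_child_layer_functions(layer, serialization_cache)
try:
    # ... trace / serialize the parent layer's call functions here ...
    pass
finally:
    # Always put back the original call and activity_regularizer attributes.
    _restore_child_layer_functions(original_fns)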
Example #6
def _reset_layer_losses(parent_layer):
  """Resets losses of layer and its sublayers, and returns original losses."""
  losses_dict = {}
  for layer in _list_all_layers(parent_layer) + [parent_layer]:
    losses_dict[layer] = {'losses': layer._losses[:],
                          'eager_losses': layer._eager_losses[:]}
    with trackable.no_automatic_dependency_tracking_scope(layer):
      layer._losses = []
      layer._eager_losses = []
  return losses_dict
Example #7
def main(_):
    tr1 = base.Trackable()
    v = tf.Variable(1)
    tr1._track_trackable(v, name='tr1_v')
    for _ in range(3):
        trackable(tr1, v)

    tr2 = tracking.AutoTrackable()
    tracked, untracked = tf.Variable(1000), tf.Variable(0)
    tr2.v = tracked
    with base.no_automatic_dependency_tracking_scope(tr2):
        tr2.untracked = untracked
    for _ in range(2):
        autotrackable(tr2, tracked, untracked)

    listing()

    deleting(tr2)

    tr3 = tracking.AutoTrackable()
    br1 = tracking.AutoTrackable()
    br1.v = tf.Variable(5)
    br2 = tracking.AutoTrackable()
    br2.v = tf.Variable(5)
    tr3.br_list = [br1, br2]
    br3 = tracking.AutoTrackable()
    br3.v = tf.Variable(5)
    tr3.br_dict = {'br3': br3}
    containers(tr3)

    tr3.br_dict = {'br1': br1, 'br2': br2, 'br3': br3}
    sharing(tr3)

    mod1 = Module('m1')
    mod1.sub = Module('m2')
    mod1.sub.sub = Module('m3')
    modules(mod1)

    # @tf.function
    # def tracer1():
    #     return mod1()

    # graph(tracer1)

    ins = [tf.keras.Input(shape=(), dtype=tf.int32)]
    lay = Layer(name='l1', sub=Layer(name='l2', sub=Layer(name='l3')))
    outs = [lay(ins)]
    mod2 = tf.keras.Model(name='m2', inputs=ins, outputs=outs)
    models(mod2, lay)

    @tf.function
    def tracer2():
        return mod2(tf.constant([100, 100]))

    graph(tracer2)
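As a quick check of what Example #7 demonstrates: attributes assigned inside the scope do not become checkpoint dependencies. A small self-contained sketch (using the same import paths as Example #7; they may differ across TensorFlow versions):

import tensorflow as tf
from tensorflow.python.training.tracking import base, tracking

obj = tracking.AutoTrackable()
obj.tracked = tf.Variable(1.0)          # tracked as a checkpoint dependency
with base.no_automatic_dependency_tracking_scope(obj):
    obj.untracked = tf.Variable(2.0)    # this assignment is not tracked

# Only `tracked` appears as a dependency, so only it would be written by
# tf.train.Checkpoint(obj=obj).save(...).
print([ref.name for ref in obj._checkpoint_dependencies])  # ['tracked']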
Example #8
 def replace_metric_functions(child_layer, serialized_fns):
   """Replaces metric functions with wrapped functions."""
   original_fns[child_layer] = {
       '__call__': child_layer.__call__,
       'result': child_layer.result,
       'update_state': child_layer.update_state
   }
   with trackable.no_automatic_dependency_tracking_scope(child_layer):
     child_layer.__call__ = serialized_fns['__call__']
     child_layer.result = serialized_fns['result']
     child_layer.update_state = serialized_fns['update_state']
Example #9
    def _init_from_metadata(cls, metadata):
        """Revives the saved InputLayer from the Metadata."""
        init_args = dict(name=metadata['name'],
                         dtype=metadata['dtype'],
                         sparse=metadata['sparse'],
                         ragged=metadata['ragged'],
                         batch_input_shape=metadata['batch_input_shape'])
        revived_obj = cls(**init_args)
        with trackable.no_automatic_dependency_tracking_scope(revived_obj):
            revived_obj._config = metadata['config']  # pylint:disable=protected-access

        return revived_obj, setattr
Example #10
 def _capture_init(self, *args, **kwargs):
     """Captures init args and kwargs and stores them into `_saved_kwargs`."""
     if len(args) > len(arg_spec.args) + 1:
         # Error case: more inputs than args.  Call init so that the appropriate
         # error can be raised to the user.
         init(self, *args, **kwargs)
     # Convert to a canonical kwarg format.
     kwargs = tf_inspect.getcallargs(init, self, *args, **kwargs)
     kwargs.pop("self")
     init(self, **kwargs)
     # Avoid auto tracking which prevents keras from tracking layers that are
     # passed as kwargs to the Network.
     with base.no_automatic_dependency_tracking_scope(self):
         setattr(self, "_saved_kwargs", kwargs)
Example #11
 def _capture_init(self, *args, **kwargs):
   """Captures init args and kwargs and stores them into `_saved_kwargs`."""
   if len(args) > len(arg_spec.args) + 1:
     # Error case: more inputs than args.  Call init so that the appropriate
     # error can be raised to the user.
     init(self, *args, **kwargs)
   for i, arg in enumerate(args):
     # Add +1 to skip `self` in arg_spec.args.
     kwargs[arg_spec.args[1 + i]] = arg
   init(self, **kwargs)
   # Avoid auto tracking which prevents keras from tracking layers that are
   # passed as kwargs to the Network.
   with base.no_automatic_dependency_tracking_scope(self):
     setattr(self, "_saved_kwargs", kwargs)
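For context, `_capture_init` in Examples #10 and #11 is written to replace a class's `__init__`. A hypothetical sketch of how such a wrapper could be installed; this is illustrative only and, unlike the originals, does not canonicalize positional arguments with `tf_inspect.getcallargs`:

import functools
from tensorflow.python.training.tracking import base

def capture_saved_kwargs(cls):
    """Illustrative decorator: record constructor kwargs without tracking them."""
    init = cls.__init__

    @functools.wraps(init)
    def _capture_init(self, *args, **kwargs):
        init(self, *args, **kwargs)
        # The scope keeps the stored dict out of automatic dependency tracking.
        with base.no_automatic_dependency_tracking_scope(self):
            setattr(self, "_saved_kwargs", dict(kwargs))

    cls.__init__ = _capture_init
    return cls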
Example #12
    def _init_from_metadata(cls, metadata):
        """Create revived network from metadata stored in the SavedModel proto."""
        revived_obj = cls(name=metadata['name'])

        # Store attributes revived from SerializedAttributes in an un-tracked
        # dictionary. The attributes are the ones listed in CommonEndpoints or
        # "keras_api" for keras-specific attributes.
        with trackable.no_automatic_dependency_tracking_scope(revived_obj):
            # pylint:disable=protected-access
            revived_obj._serialized_attributes = {'metadata': metadata}
            _set_network_attributes_from_metadata(revived_obj)
            # pylint:enable=protected-access

        return revived_obj, _revive_setter  # pylint:disable=protected-access
Example #13
def _set_network_attributes_from_metadata(revived_obj):
    """Sets attributes recorded in the metadata."""
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        metadata = revived_obj._serialized_attributes['metadata']
        if metadata.get('dtype') is not None:
            revived_obj._dtype = metadata['dtype']
        revived_obj.trainable = metadata['trainable']

        revived_obj._expects_training_arg = metadata['expects_training_arg']
        if metadata.get('config') is not None:
            revived_obj._config = metadata['config']

        if metadata.get('activity_regularizer') is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                metadata['activity_regularizer'])
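For illustration, a hypothetical metadata dict covering the keys that `_set_network_attributes_from_metadata` reads in Examples #1 and #13 (the values are made up):

metadata = {
    'dtype': 'float32',                # optional; sets the dtype policy / _dtype
    'trainable': True,                 # always read
    'expects_training_arg': False,     # read in Example #13
    'config': {'name': 'my_network'},  # optional
    'activity_regularizer': None,      # optional; deserialized when present
}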
Example #14
 def replace_layer_functions(child_layer, serialized_fns):
   """Replaces layer call and activity regularizer with wrapped functions."""
   original_fns[child_layer] = {
       'call': child_layer.call,
       '_activity_regularizer': child_layer._activity_regularizer
   }
   with trackable.no_automatic_dependency_tracking_scope(child_layer):
     try:
       child_layer._activity_regularizer = serialized_fns.get(
           'activity_regularizer_fn')
     except AttributeError:
       # Some layers have an unsettable activity regularizer.
       pass
     child_layer.call = utils.use_wrapped_call(
         child_layer,
         serialized_fns['call_and_return_conditional_losses'],
         default_training_value=False)
Example #15
 def call(self,
          inputs,
          training=None,
          sample_shape=(),
          projection=None,
          **kwargs):
   ## NOTE: 2D inputs are important here, but we don't want to flatten
   # them automatically
   if self.flatten_inputs:
     inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
   params = inputs
   ## do not use tf.cond here, it infers the wrong shape when
   # trying to build the layer in Graph mode.
   projection = projection if projection is not None else self.projection
   if projection:
     params = self._dense(params)
     if self.autoregressive:
       params = tf.concat(tf.unstack(params, axis=-1), axis=-1)
   ## applying dropout
   if self._dropout > 0:
     params = bk.dropout(params, p_drop=self._dropout, training=training)
   ## create posterior distribution
   self._posterior_sample_shape = sample_shape
   kw = dict()
   if 'training' in self._posterior_call_kw:
     kw['training'] = training
   if 'sample_shape' in self._posterior_call_kw:
     kw['sample_shape'] = sample_shape
   for k, v in kwargs.items():
     if k in self._posterior_call_kw:
       kw[k] = v
   posterior = self.posterior_layer(params, **kw)
   # tensorflow tries to serialize the distribution, which raises an exception
   # when saving the graph; to avoid this, store it as a non-tracked attribute.
   with trackable.no_automatic_dependency_tracking_scope(self):
     # self._no_dependency
     self._most_recently_built_distribution = posterior
   ## NOTE: all distributions have a kl_divergence method, so we cannot use that name
   posterior.KL_divergence = KLdivergence(
     posterior, prior=self.prior,
     sample_shape=None)  # None mean reuse sampled data here
   return posterior
Example #16
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    revived_obj = cls(name=metadata['name'])

    # Store attributes revived from SerializedAttributes in an un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access
Example #17
 def __new__(cls, *args, **kwargs):
   class_tree = [c for c in type.mro(cls) if issubclass(c, keras.Model)][::-1]
   # get default arguments from parent classes
   kw = dict()
   for c in class_tree:
     spec = inspect.getfullargspec(c.__init__)
     if spec.defaults is not None:
       for key, val in zip(spec.args[::-1], spec.defaults[::-1]):
         kw[key] = val
   # update the user provided arguments
   for k, v in zip(spec.args[1:], args):
     kw[k] = v
   kw.update(kwargs)
   # copying is necessary here, otherwise the init function would modify
   # the arguments
   kw = copy.copy(kw)
   # create the instance
   instance = super().__new__(cls, *args, **kwargs)
   # must make _init_args NonDependency (i.e. nontrackable and won't be
   # saved in save_weights)
   with trackable.no_automatic_dependency_tracking_scope(instance):
     instance._init_args = kw
   return instance
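The point of stashing `_init_args` in Example #17 is that the constructor arguments survive on the instance without becoming checkpoint dependencies (so they are not written by `save_weights`). A hedged usage sketch, assuming `obj` is an instance of such a class:

# Rebuild an equivalent instance from the captured constructor arguments.
clone = type(obj)(**obj._init_args)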
Example #18
def _add_serialized_attributes(layer, metadata):
    # Store attributes revived from SerializedAttributes in an un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(layer):
        layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
Example #19
 def wrapper(self, *args, **kwargs):
     self._init_args = args
     with base.no_automatic_dependency_tracking_scope(self):
         setattr(self, "_init_kwargs", kwargs)
     return func(self, *args, **kwargs)
Example #20
    def build(self, input_shape=None):
        """
        Create any Variables used in the model.

        Parameters
        ----------
        input_shape : list of tuple of int
            Shapes of all the inputs to this layer.
        """

        super().build(input_shape)

        tf.random.set_seed(self.seed)

        def get_initializer(init_vals):
            """Use more efficient initializers if possible to save memory."""

            values, shapes, dtype, minibatched = init_vals

            # initial value of None means that the initial value isn't used, so we
            # can use anything for the initial value
            if all(v is None for v in values):
                initializer = None
            elif all(v is None or np.all(v == 0) for v in values):
                initializer = tf.initializers.zeros()
            elif all(v is None or np.all(v == 1) for v in values):
                initializer = tf.initializers.ones()
            else:
                val = tf.concat(
                    [
                        tf.zeros(s, dtype) if v is None else tf.cast(
                            tf.broadcast_to(v, s), dtype)
                        for v, s in zip(values, shapes)
                    ],
                    axis=1 if minibatched else 0,
                )
                initializer = lambda shape=None, dtype=None: val

            # figure out shape of full concatenated initial value
            shape = list(shapes[0])
            shape[minibatched] = sum(x[minibatched] for x in shapes)

            return initializer, tuple(shape), dtype

        # variables for model parameters
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.base_params = OrderedDict()
        assert len(self.base_params) == 0
        for sig_type in ("trainable", "non_trainable"):
            for k, v in self.base_arrays_init[sig_type].items():
                initializer, shape, dtype = get_initializer(v)
                assert initializer is not None  # params should never be set
                self.base_params[k] = self.add_weight(
                    initializer=initializer,
                    shape=shape,
                    dtype=dtype,
                    trainable=sig_type == "trainable",
                    name="base_params/%s_%s_%s" %
                    (sig_type, dtype, "_".join(str(x) for x in shape)),
                )

        logger.debug("created base param variables")
        logger.debug([str(x) for x in self.base_params.values()])

        # variables to save the internal state of simulation between runs
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.saved_state = OrderedDict()
        for k, v in self.base_arrays_init["state"].items():
            initializer, shape, dtype = get_initializer(v)
            if initializer is not None:
                # don't need to save the state for signals where the initial value
                # doesn't matter
                self.saved_state[k] = tf.Variable(
                    initial_value=lambda: initializer(shape=shape, dtype=dtype
                                                      ),
                    shape=shape,
                    dtype=dtype,
                    trainable=False,
                    name="saved_state/%s_%s" %
                    (dtype, "_".join(str(x) for x in shape)),
                )

        logger.debug("created saved state variables")
        logger.debug([str(x) for x in self.saved_state.values()])

        # call build on any TensorNode Layers

        def unbuild(layer):
            assert layer.built

            # clear any losses attached to layer (they will be recreated in the
            # build step, so we don't want to keep around any losses
            # associated with the previous build)
            # note: not clearing layer._losses, because those are manually added
            # by the user (not created during the build process)
            layer._eager_losses = []
            layer._callable_losses = []

            layer.built = False

            for sub in layer._layers:
                if isinstance(sub, tf.keras.layers.Layer):
                    unbuild(sub)

        layer_ops = [
            op for ops in self.plan
            if isinstance(ops[0], tensor_node.SimTensorNode) for op in ops
            if isinstance(op.func, tf.keras.layers.Layer)
        ]
        weight_gets = []
        weight_sets = []
        for op in layer_ops:
            if op.func in self._layers:
                # already built this layer
                continue

            if op.time is None:
                shape_in = []
            else:
                shape_in = [()]
            if op.input is not None:
                shape_in += [(self.minibatch_size, ) + op.shape_in]
            if len(shape_in) == 1:
                shape_in = shape_in[0]

            if op.func.built:
                # we rebuild the layer (even if it is already built),
                # because we need to build the weights within the TensorGraph
                # context

                # save the weight values so they can be restored
                # exactly inside the tensornode
                weights = op.func.weights
                weight_gets.extend(weights)

                # clear the results of previous build
                unbuild(op.func)
            else:
                weights = None

            with tf.name_scope(op.func.name):
                op.func.build(shape_in)

            if weights is not None:
                weight_sets.extend(op.func.weights)

            # add op func to _layers so that any weights are collected
            self._layers.append(op.func)

        if len(weight_gets) > 0:
            # do all the weight getting/setting in one go, for efficiency reasons
            ctx = (weight_gets[0].graph.as_default() if hasattr(
                weight_gets[0], "graph") else context.eager_mode())
            with ctx:
                weight_vals = tf.keras.backend.batch_get_value(weight_gets)

            tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals))

        # initialize state variables (need to do this manually because we're not
        # adding them to self.weights)
        # note: don't need to do this in eager mode, since variables are
        # initialized on creation
        # TODO: why does this cause problems if it is done before the tensornode
        #  weight get/sets above?
        if not context.executing_eagerly():
            tf.keras.backend.batch_get_value(
                [var.initializer for var in self.saved_state.values()])
Example #21
    def call(self, inputs, training=None, progress=None, stateful=False):
        """
        Constructs the graph elements to simulate the model.

        Parameters
        ----------
        inputs : list of ``tf.Tensor``
            Input layers/tensors for the network (must match the structure defined in
            `.build_inputs`).
        training : bool
            Whether the network is being run in training or inference mode.  If None,
            uses the symbolic Keras learning phase variable.
        progress : `.utils.ProgressBar`
            Progress bar for construction stage.
        stateful : bool
            Whether or not to build the model to support preserving the internal state
            between executions.

        Returns
        -------
        probe_arrays : list of ``tf.Tensor``
            Tensors representing the output of all the Probes in the network (order
            corresponding to ``self.model.probes``, which is the order the Probes were
            instantiated).
        """

        super().call(inputs, training=training)

        if training == 1 and self.inference_only:
            raise BuildError(
                "TensorGraph was created with inference_only=True; cannot be built "
                "with training=%s" % training)

        tf.random.set_seed(self.seed)

        if progress is None:
            progress = utils.NullProgressBar()

        # reset signaldict
        self.signals.reset()

        # create these constants once here for reuse in different operators
        self.signals.dt = tf.constant(self.dt, self.dtype)
        self.signals.dt_val = self.dt  # store the actual value as well
        self.signals.zero = tf.constant(0, self.dtype)
        self.signals.one = tf.constant(1, self.dtype)

        # set up invariant inputs
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.node_inputs = {}
        for n, inp in zip(self.invariant_inputs, inputs):
            # specify shape of inputs (keras sometimes loses this shape information)
            inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])

            self.node_inputs[n] = inp

        self.steps_to_run = inputs[-1][0, 0]

        # initialize op builder
        build_config = builder.BuildConfig(
            inference_only=self.inference_only,
            lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
            cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
            rng=np.random.RandomState(self.seed),
            training=(tf.keras.backend.learning_phase()
                      if training is None else training),
        )
        self.op_builder = builder.Builder(self.plan, self.signals,
                                          build_config)

        # pre-build stage
        with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
            self.op_builder.build_pre(sub)

        # build stage
        with progress.sub("build stage",
                          max_value=len(self.plan) * self.unroll) as sub:
            steps_run, probe_arrays, final_internal_state, final_base_params = (
                self._build_loop(sub)
                if self.use_loop else self._build_no_loop(sub))

        # store these so that they can be accessed after the initial build
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.steps_run = steps_run
            self.probe_arrays = probe_arrays
            self.final_internal_state = final_internal_state
            self.final_base_params = final_base_params

        # logging
        logger.info("Number of reads: %d",
                    sum(x for x in self.signals.read_types.values()))
        for x in self.signals.read_types.items():
            logger.info("    %s: %d", *x)
        logger.info("Number of writes: %d",
                    sum(x for x in self.signals.write_types.values()))
        for x in self.signals.write_types.items():
            logger.info("    %s: %d", *x)

        # note: always return steps_run so that the simulation will run for the given
        # number of steps, even if there are no output probes
        outputs = list(probe_arrays.values()) + [steps_run]

        updates = []
        if stateful:
            # update saved state
            updates.extend(
                var.assign(val) for var, val in zip(self.saved_state.values(),
                                                    final_internal_state))

        # if any of the base params have changed (due to online learning rules) then we
        # also need to assign those back to the original variable (so that their
        # values will persist). any parameters targeted by online learning rules
        # will be minibatched, so we only need to update the minibatched params.
        for (key, var), val in zip(self.base_params.items(),
                                   final_base_params):
            try:
                minibatched = self.base_arrays_init["non_trainable"][key][-1]
            except KeyError:
                minibatched = self.base_arrays_init["trainable"][key][-1]

            if minibatched:
                updates.append(var.assign(val))

        logger.info("Number of variable updates: %d", len(updates))

        if len(updates) > 0:
            with tf.control_dependencies(updates):
                outputs = [tf.identity(x) for x in outputs]

        return outputs
Example #22
  def fit(self,
          train: Union[TensorTypes, DatasetV2],
          valid: Optional[Union[TensorTypes, DatasetV2]] = None,
          valid_freq: int = 500,
          valid_interval: float = 0,
          optimizer: Union[str, List[str], OptimizerV2,
                           List[OptimizerV2]] = 'adam',
          learning_rate: float = 1e-3,
          clipnorm: Optional[float] = None,
          epochs: int = -1,
          max_iter: int = 1000,
          batch_size: int = 32,
          callback: Union[Callback, List[Callback]] = lambda: None,
          compile_graph: bool = True,
          autograph: bool = False,
          logging_interval: float = 3,
          skip_fitted: Union[bool, int] = False,
          terminate_on_nan: bool = True,
          logdir: Optional[str] = None,
          allow_none_gradients: bool = False,
          track_gradients: bool = False) -> Networks:
    """Override the original fit method of keras to provide simplified
    procedure with `Networks.optimize` and `Networks.train_steps`

    Parameters
    ----------
    train : Union[TensorTypes, DatasetV2]
        tensorflow Dataset for training
    valid : Optional[Union[TensorTypes, DatasetV2]], optional
        tensorflow Dataset for validation, by default None
    valid_freq : int, optional
        the frequency, in steps, for performing validation, by default 500
    valid_interval : float, optional
        the interval, in seconds, for performing validation, by default 0
    optimizer : Union[str, OptimizerV2], optional
        A list of optimizers is accepted in the case of multi-step training.
        If `None`, re-use the stored optimizer and raise `RuntimeError` if no
        predefined optimizer is found, by default 'adam'
    learning_rate : float, optional
        learning rate for initializing the optimizer, by default 1e-3
    clipnorm : Optional[float], optional
        global L2-norm value for clipping the gradients, by default None
    epochs : int, optional
        maximum number of epochs, by default -1
    max_iter : int, optional
        maximum number of iterations, by default 1000
    batch_size : int, optional
        number of examples for mini-batch, by default 32
    callback : Union[Callback, List[Callback]], optional
        a function or list of functions called every `valid_freq` steps or
        `valid_interval` seconds, by default lambda:None
    compile_graph : bool, optional
        If True, use tensorflow autograph for the optimize function (about 2 times
        speed gain); otherwise, run the function in Eager mode (better for
        debugging), by default True
    autograph : bool, optional
        use autograph to compile the function, by default False
    logging_interval : float, optional
        interval, in seconds, for printing logging information, by default 3
    skip_fitted : Union[bool, int], optional
        skip this function if the model is fitted, or fitted for a certain number
        of steps, by default False
    terminate_on_nan : bool, optional
        terminate the training if NaNs returned, by default True
    logdir : Optional[str], optional
        tensorboard logging directory, by default None
    allow_none_gradients : bool, optional
        allow variables with None gradients during training, by default False
    track_gradients : bool, optional
        track and return metrics that include the gradients' L2-norm for each
        trainable variable, by default False

    Returns
    -------
    Networks
        the network itself for method chaining

    Raises
    ------
    RuntimeError
        if the optimizer is not defined.
    """
    if not self.built:
      raise RuntimeError(
          "build(input_shape) method must be called to initialize "
          "the variables before calling fit")
    batch_size = int(batch_size)
    # validate the dataset
    train = _to_dataset(train, batch_size, self.dtype)
    if valid is not None:
      valid = _to_dataset(valid, batch_size, self.dtype)
    # skip training if model is fitted or reached a number of iteration
    if self.is_fitted and skip_fitted:
      if isinstance(skip_fitted, bool):
        return self
      skip_fitted = int(skip_fitted)
      if int(self.step.numpy()) >= skip_fitted:
        return self
    # create the trainer
    if self.trainer is None:
      with trackable.no_automatic_dependency_tracking_scope(self):
        trainer = Trainer(logdir=logdir)
        self.trainer = trainer
    else:
      trainer = self.trainer
    ## if already called repeat, then no need to repeat more
    if hasattr(train, 'repeat'):
      train = train.repeat(int(epochs))
    ## create the optimizer, turn off tracking so the optimizer
    # won't be saved in save_weights
    if optimizer is not None and self.optimizer is None:
      with trackable.no_automatic_dependency_tracking_scope(self):
        self.optimizer = _to_optimizer(optimizer, learning_rate, clipnorm)
    if self.optimizer is None:
      raise RuntimeError("No optimizer found!")
    ## run early stop and callback
    self.trainer.fit(
        train_ds=train,
        optimize=partial(self.optimize,
                         allow_none_gradients=allow_none_gradients,
                         track_gradients=track_gradients),
        valid_ds=valid,
        valid_freq=valid_freq,
        valid_interval=valid_interval,
        compile_graph=compile_graph,
        autograph=autograph,
        logging_interval=logging_interval,
        log_tag=self.name,
        max_iter=max_iter,
        terminate_on_nan=terminate_on_nan,
        callback=callback,
    )
    return self
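A hedged usage sketch for the `fit` method above; `net` is assumed to be a built `Networks` instance and `train_ds`/`valid_ds` are `tf.data.Dataset` objects (the input shape is hypothetical):

net.build(input_shape=(None, 64))   # fit() raises RuntimeError if build() was not called
net.fit(train_ds,
        valid=valid_ds,
        valid_freq=500,
        optimizer='adam',
        learning_rate=1e-3,
        max_iter=10000,
        batch_size=32)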
Example #23
    def fit(
            self,
            train: Union[TensorTypes, DatasetV2],
            valid: Optional[Union[TensorTypes, DatasetV2]] = None,
            valid_freq: int = 500,
            valid_interval: float = 0,
            optimizer: Union[str, OptimizerV2] = 'adam',
            learning_rate: float = 1e-3,
            clipnorm: Optional[float] = None,
            epochs: int = -1,
            max_iter: int = 1000,
            batch_size: int = 32,
            sample_shape: List[int] = (),  # for ELBO
            analytic: Optional[bool] = None,  # for ELBO
            iw: bool = False,  # for ELBO
            callback: Callable[[], Optional[dict]] = lambda: None,
            compile_graph: bool = True,
            autograph: bool = False,
            logging_interval: float = 3,
            skip_fitted: bool = False,
            terminate_on_nan: bool = True,
            logdir: Optional[str] = None,
            allow_none_gradients: bool = False,
            track_gradients: bool = False):
        r""" Override the original fit method of keras to provide simplified
    procedure with `Networks.optimize` and
    `Networks.train_steps`

    Arguments:
      optimizer : Text, instance of `tf.optimizers.Optimizer`
        or `None`. A list of optimizers is accepted in the case of multi-step
        training.
        - If `None`, re-use the stored optimizer, raise `RuntimeError` if no
          predefined optimizer is found.
      callback : a Callable, called every `valid_freq` steps or
        `valid_interval` seconds
      compile_graph : a Boolean. If True, use tensorflow autograph for the
        optimize function (about 2 times better speed); otherwise, run the
        function in Eager mode (better for debugging).

    """
        batch_size = int(batch_size)
        # validate the dataset
        train = _to_dataset(train, batch_size, self.dtype)
        if valid is not None:
            valid = _to_dataset(valid, batch_size, self.dtype)
        # skip training if model is fitted or reached a number of iteration
        if self.is_fitted and skip_fitted:
            if isinstance(skip_fitted, bool):
                return self
            skip_fitted = int(skip_fitted)
            if int(self.step.numpy()) >= skip_fitted:
                return self
        # create the trainer
        if self.trainer is None:
            with trackable.no_automatic_dependency_tracking_scope(self):
                trainer = Trainer(logdir=logdir)
                self.trainer = trainer
        else:
            trainer = self.trainer
        ## if already called repeat, then no need to repeat more
        if hasattr(train, 'repeat'):
            train = train.repeat(int(epochs))
        ## create the optimizer, turn off tracking so the optimizer
        # won't be saved in save_weights
        if optimizer is not None and self.optimizer is None:
            with trackable.no_automatic_dependency_tracking_scope(self):
                self.optimizer = _to_optimizer(optimizer, learning_rate,
                                               clipnorm)
        if self.optimizer is None:
            raise RuntimeError("No optimizer found!")
        ## run early stop and callback
        self.trainer.fit(
            train_ds=train,
            optimize=partial(self.optimize,
                             allow_none_gradients=allow_none_gradients,
                             track_gradients=track_gradients),
            valid_ds=valid,
            valid_freq=valid_freq,
            valid_interval=valid_interval,
            compile_graph=compile_graph,
            autograph=autograph,
            logging_interval=logging_interval,
            log_tag=self.name,
            max_iter=max_iter,
            terminate_on_nan=terminate_on_nan,
            callback=callback,
        )
        return self
Example #24
def _restore_layer_losses(losses_dict):
    for layer in losses_dict:
        with trackable.no_automatic_dependency_tracking_scope(layer):
            layer._losses = losses_dict[layer]['losses']
            layer._eager_losses = losses_dict[layer]['eager_losses']
Example #25
    def _finalize(self):
        # pylint: disable=protected-access

        # Set up call functions for all layers (skip this step for Sequential and
        # Functional models).
        for node in self._nodes:
            if isinstance(node, RevivedLayer):
                node.built = True
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if not (isinstance(node, models_lib.Sequential)
                        or is_graph_network):
                    if hasattr(node.keras_api,
                               'call_and_return_conditional_losses'):
                        node.call = utils.use_wrapped_call(
                            node,
                            node.keras_api.call_and_return_conditional_losses,
                            return_method=True)
                        node._init_call_fn_args()

        for node in self._nodes:
            if isinstance(node, RevivedNetwork):
                call_fn = node.keras_api.call_and_return_conditional_losses
                if call_fn.input_signature is None:
                    inputs = infer_inputs_from_restored_call_function(call_fn)
                else:
                    inputs = call_fn.input_signature[0]

                # Set model inputs and outputs.
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if isinstance(node, models_lib.Sequential):
                    with trackable.no_automatic_dependency_tracking_scope(
                            node):
                        node._layers = []
                    for layer in node.keras_api.layers:
                        node.add(layer)
                elif is_graph_network:
                    # Reconstruct functional model from the config and layers loaded
                    # from the SavedModel.
                    inputs, outputs, _ = network_lib.reconstruct_from_config(
                        node.get_config(),
                        created_layers={
                            layer.name: layer
                            for layer in node.layers
                        })
                    node._init_graph_network(
                        inputs,
                        outputs,
                        name=node._serialized_attributes['metadata']['name'])
                    # Set the metadata attributes once more, since _init_graph_network
                    # resets these attributes.
                    _set_network_attributes_from_metadata(node)
                else:  # Model is subclassed.
                    node._set_inputs(inputs)

            # Add unconditional losses.
            if isinstance(node, RevivedLayer):
                if hasattr(node.keras_api, 'layer_regularization_losses'):
                    losses = getattr(node.keras_api,
                                     'layer_regularization_losses', [])
                else:
                    # Some earlier SavedModels may not have layer_regularization_losses
                    # serialized separately. Fall back to using the regularization_losses
                    # list if it does not exist.
                    losses = node._serialized_attributes.get(
                        'regularization_losses', [])
                for loss in losses:
                    node.add_loss(loss)

                # Use wrapped activity regularizer function if the layer's activity
                # regularizer wasn't created during initialization.
                if node.activity_regularizer is None:
                    node.activity_regularizer = getattr(
                        node.keras_api, 'activity_regularizer_fn', None)

                # Now that the node object has been fully loaded and restored from the
                # checkpoint, the object no longer needs to track objects added from
                # SerializedAttributes. (Note that saving a training checkpoint still
                # functions correctly, because layers and variables are tracked
                # separately by the Layer object.)
                # TODO(kathywu): Instead of outright deleting these nodes (which would
                # make restoring from a different checkpoint tricky), mark them as extra
                # dependencies that are OK to overwrite.
                for name in PUBLIC_ATTRIBUTES:
                    delete_tracking(node, name)