def _set_network_attributes_from_metadata(revived_obj):
    """Applies the dtype policy and `trainable` flag stored in the metadata.

    Args:
      revived_obj: Revived object whose `_serialized_attributes['metadata']`
        dictionary holds the saved values.
    """
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        saved = revived_obj._serialized_attributes['metadata']
        saved_dtype = saved.get('dtype')
        # A missing or `None` dtype leaves the current dtype policy untouched.
        if saved_dtype is not None:
            revived_obj._set_dtype_policy(saved_dtype)
        revived_obj.trainable = saved['trainable']
def _init_from_metadata(cls, metadata):
    """Builds a revived model from the metadata stored in the SavedModel proto.

    Args:
      cls: Class being revived.
      metadata: Dictionary of saved attributes; `training_config` is optional.

    Returns:
      Whatever the parent `_init_from_metadata` returns, with
      `_training_config` attached as an untracked attribute.
    """
    model = super(RevivedModel, cls)._init_from_metadata(metadata)
    training_config = metadata.get('training_config')
    with trackable.no_automatic_dependency_tracking_scope(model):
        model._training_config = training_config  # pylint:disable=protected-access
    return model
def _replace_child_layer_functions(layer, serialization_cache):
    """Swaps child layers' `call`/`activity_regularizer` for wrapped versions.

    Letting parent layers reference the already wrapped functions of their
    children avoids retracing the children's ops. Use
    `_restore_child_layer_functions` on the returned dictionary to undo the
    replacement.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      Dictionary mapping each modified child layer to its original attributes:
        { Child layer 1: {
            'call': Original call function,
            'activity_regularizer': Original activity regularizer},
          Child layer 2: ...
        }
    """
    # pylint: disable=protected-access
    replaced = {}
    for child in _list_all_layers(layer):
        keras_cache = serialization_cache[constants.KERAS_CACHE_KEY]
        if child in keras_cache:
            wrapped_fns = keras_cache[child].functions
        else:
            serialized = (
                child._trackable_saved_model_saver._get_serialized_attributes(
                    serialization_cache))
            wrapped_fns = serialized.functions

        if not wrapped_fns:
            # Either there is a circular dependency (the current layer's
            # functions must be wrapped first), or the child's inputs are not
            # defined so nothing was wrapped. In both cases there is nothing
            # to replace for this child.
            continue

        replaced[child] = {
            'call': child.call,
            'activity_regularizer': child.activity_regularizer,
        }
        with trackable.no_automatic_dependency_tracking_scope(child):
            try:
                child.activity_regularizer = wrapped_fns.get(
                    'activity_regularizer_fn')
            except AttributeError:
                # Some layers have an unsettable activity regularizer.
                pass
            child.call = utils.use_wrapped_call(
                child,
                wrapped_fns['call_and_return_conditional_losses'],
                default_training_value=False)
    return replaced
def _restore_child_layer_functions(original_fns):
    """Puts back the attributes swapped by `_replace_child_layer_functions`.

    Args:
      original_fns: Dictionary mapping each child layer to a dictionary of
        `{attribute name: original value}`.
    """
    for layer, saved_attrs in original_fns.items():
        with trackable.no_automatic_dependency_tracking_scope(layer):
            for attr_name in saved_attrs:
                try:
                    setattr(layer, attr_name, saved_attrs[attr_name])  # pylint: disable=protected-access
                except AttributeError:
                    # Best effort: an attribute such as `_activity_regularizer`
                    # can refuse assignment (the original note here is cut off
                    # in the source; presumably "may be disallowed").
                    pass
def _restore_child_layer_functions(original_fns):
    """Puts back `call` and `activity_regularizer` on every child layer.

    Args:
      original_fns: Dictionary produced by `_replace_child_layer_functions`,
        mapping each child layer to its original `'call'` and
        `'activity_regularizer'` values.
    """
    for layer in original_fns:
        saved = original_fns[layer]
        with trackable.no_automatic_dependency_tracking_scope(layer):
            layer.call = saved['call']
            try:
                layer.activity_regularizer = saved['activity_regularizer']
            except AttributeError:
                # The regularizer attribute is not settable on this layer.
                pass
def _reset_layer_losses(parent_layer):
    """Clears the losses of a layer and its sublayers.

    Args:
      parent_layer: Layer whose own losses, and those of every layer returned
        by `_list_all_layers`, are emptied.

    Returns:
      Dictionary mapping each layer to copies of its previous `'losses'` and
      `'eager_losses'` lists.
    """
    saved_losses = {}
    for current in _list_all_layers(parent_layer) + [parent_layer]:
        # Copy before clearing so the snapshot is independent of the layer.
        snapshot = {
            'losses': current._losses[:],
            'eager_losses': current._eager_losses[:],
        }
        saved_losses[current] = snapshot
        with trackable.no_automatic_dependency_tracking_scope(current):
            current._losses = []
            current._eager_losses = []
    return saved_losses
def main(_):
    """Runs the dependency-tracking demos in sequence.

    Builds `Trackable`, `AutoTrackable`, `Module` and Keras model objects and
    hands each to the matching demo helper (`trackable`, `autotrackable`,
    `listing`, `deleting`, `containers`, `sharing`, `modules`, `models`,
    `graph`), which are defined outside this block.

    NOTE(review): the original indentation of this block was lost; statements
    that follow each `for` loop are assumed to sit outside the loop body.
    """
    # Explicitly tracked dependency on a plain Trackable.
    tr1 = base.Trackable()
    v = tf.Variable(1)
    tr1._track_trackable(v, name='tr1_v')
    for _ in range(3):
        trackable(tr1, v)
    # Attribute assignment on an AutoTrackable: one assignment done normally,
    # one done inside the no-tracking scope.
    tr2 = tracking.AutoTrackable()
    tracked, untracked = tf.Variable(1000), tf.Variable(0)
    tr2.v = tracked
    with base.no_automatic_dependency_tracking_scope(tr2):
        tr2.untracked = untracked
    for _ in range(2):
        autotrackable(tr2, tracked, untracked)
    listing()
    deleting(tr2)
    # Trackables held in list and dict containers.
    tr3 = tracking.AutoTrackable()
    br1 = tracking.AutoTrackable()
    br1.v = tf.Variable(5)
    br2 = tracking.AutoTrackable()
    br2.v = tf.Variable(5)
    tr3.br_list = [br1, br2]
    br3 = tracking.AutoTrackable()
    br3.v = tf.Variable(5)
    tr3.br_dict = {'br3': br3}
    containers(tr3)
    # The same branches are now reachable through both containers.
    tr3.br_dict = {'br1': br1, 'br2': br2, 'br3': br3}
    sharing(tr3)
    # Three nested modules.
    mod1 = Module('m1')
    mod1.sub = Module('m2')
    mod1.sub.sub = Module('m3')
    modules(mod1)
    # Disabled demo kept from the original:
    # @tf.function
    # def tracer1():
    #     return mod1()
    # graph(tracer1)
    # Three nested layers wrapped in a functional Keras model.
    ins = [tf.keras.Input(shape=(), dtype=tf.int32)]
    lay = Layer(name='l1', sub=Layer(name='l2', sub=Layer(name='l3')))
    outs = [lay(ins)]
    mod2 = tf.keras.Model(name='m2', inputs=ins, outputs=outs)
    models(mod2, lay)

    @tf.function
    def tracer2():
        return mod2(tf.constant([100, 100]))

    graph(tracer2)
def replace_metric_functions(child_layer, serialized_fns):
    """Swaps a metric's `__call__`, `result` and `update_state` functions.

    The originals are recorded in the enclosing `original_fns` dictionary.

    Args:
      child_layer: Metric object whose functions are replaced.
      serialized_fns: Mapping holding the wrapped replacement for each of the
        three function names.
    """
    swapped_names = ('__call__', 'result', 'update_state')
    original_fns[child_layer] = {
        fn_name: getattr(child_layer, fn_name) for fn_name in swapped_names
    }
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
        for fn_name in swapped_names:
            setattr(child_layer, fn_name, serialized_fns[fn_name])
def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the metadata.

    Args:
      cls: Class to instantiate.
      metadata: Dictionary providing `name`, `dtype`, `sparse`, `ragged`,
        `batch_input_shape` and `config`.

    Returns:
      Tuple of the revived layer and the setter used for its attributes
      (the builtin `setattr`).
    """
    constructor_keys = ('name', 'dtype', 'sparse', 'ragged',
                        'batch_input_shape')
    init_args = {key: metadata[key] for key in constructor_keys}
    layer = cls(**init_args)
    with trackable.no_automatic_dependency_tracking_scope(layer):
        layer._config = metadata['config']  # pylint:disable=protected-access
    return layer, setattr
def _capture_init(self, *args, **kwargs):
    """Captures init args and kwargs and stores them into `_saved_kwargs`.

    Positional arguments are converted to keyword form, the wrapped `init` is
    called with that canonical form, and the resulting dictionary is stored on
    `self` as an untracked attribute.

    Args:
      *args: Positional arguments meant for the wrapped `init`.
      **kwargs: Keyword arguments meant for the wrapped `init`.

    Raises:
      TypeError: Propagated from `init` (or `getcallargs`) when the arguments
        do not match the signature of `init`.
    """
    # `arg_spec.args` includes `self` while `args` does not, so at most
    # `len(arg_spec.args) - 1` positional arguments can be bound. The previous
    # bound (`+ 1`) let one or two surplus arguments slip past this check.
    max_positional = len(arg_spec.args) - 1
    if len(args) > max_positional:
        # Error case: more inputs than args. Call init so that the
        # appropriate error can be raised to the user.
        init(self, *args, **kwargs)
    # Convert to a canonical kwarg format.
    kwargs = tf_inspect.getcallargs(init, self, *args, **kwargs)
    kwargs.pop("self")
    init(self, **kwargs)
    # Avoid auto tracking which prevents keras from tracking layers that are
    # passed as kwargs to the Network.
    with base.no_automatic_dependency_tracking_scope(self):
        self._saved_kwargs = kwargs
def _capture_init(self, *args, **kwargs):
    """Captures init args and kwargs and stores them into `_saved_kwargs`.

    Positional arguments are mapped onto the parameter names of the wrapped
    `init`, which is then called with keyword arguments only. The resulting
    dictionary is stored on `self` as an untracked attribute.

    Args:
      *args: Positional arguments meant for the wrapped `init`.
      **kwargs: Keyword arguments meant for the wrapped `init`.

    Raises:
      TypeError: Propagated from `init` when the arguments do not match its
        signature.
    """
    # `arg_spec.args` includes `self` while `args` does not, so at most
    # `len(arg_spec.args) - 1` positional arguments can be bound. With the
    # previous bound (`+ 1`), one or two surplus arguments skipped this check
    # and surfaced as an IndexError from the mapping loop below instead of the
    # TypeError raised by `init`.
    max_positional = len(arg_spec.args) - 1
    if len(args) > max_positional:
        # Error case: more inputs than args. Call init so that the
        # appropriate error can be raised to the user.
        init(self, *args, **kwargs)
    for i, arg in enumerate(args):
        # Add +1 to skip `self` in arg_spec.args.
        kwargs[arg_spec.args[1 + i]] = arg
    init(self, **kwargs)
    # Avoid auto tracking which prevents keras from tracking layers that are
    # passed as kwargs to the Network.
    with base.no_automatic_dependency_tracking_scope(self):
        self._saved_kwargs = kwargs
def _init_from_metadata(cls, metadata):
    """Creates a revived network from metadata stored in the SavedModel proto.

    Args:
      cls: Class to instantiate; constructed with the saved `name` only.
      metadata: Dictionary of saved attributes.

    Returns:
      Tuple of the revived network and `_revive_setter`.
    """
    network = cls(name=metadata['name'])
    serialized_attributes = {'metadata': metadata}

    # Attributes revived from SerializedAttributes live in an un-tracked
    # dictionary. They are the ones listed in CommonEndpoints, or "keras_api"
    # for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(network):
        # pylint:disable=protected-access
        network._serialized_attributes = serialized_attributes
        _set_network_attributes_from_metadata(network)
        # pylint:enable=protected-access

    return network, _revive_setter  # pylint:disable=protected-access
def _set_network_attributes_from_metadata(revived_obj):
    """Copies the attributes recorded in the metadata onto the revived object.

    `trainable` and `expects_training_arg` are required metadata keys;
    `dtype`, `config` and `activity_regularizer` are applied only when present
    and not `None`.

    Args:
      revived_obj: Revived object carrying
        `_serialized_attributes['metadata']`.
    """
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        saved = revived_obj._serialized_attributes['metadata']

        saved_dtype = saved.get('dtype')
        if saved_dtype is not None:
            revived_obj._dtype = saved_dtype

        revived_obj.trainable = saved['trainable']
        revived_obj._expects_training_arg = saved['expects_training_arg']

        saved_config = saved.get('config')
        if saved_config is not None:
            revived_obj._config = saved_config

        saved_regularizer = saved.get('activity_regularizer')
        if saved_regularizer is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                saved_regularizer)
def replace_layer_functions(child_layer, serialized_fns):
    """Swaps a layer's call and activity regularizer for wrapped functions.

    The originals are recorded in the enclosing `original_fns` dictionary.

    Args:
      child_layer: Layer whose functions are replaced.
      serialized_fns: Mapping with the wrapped
        `call_and_return_conditional_losses` function and, optionally,
        `activity_regularizer_fn`.
    """
    # pylint: disable=protected-access
    saved = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer,
    }
    original_fns[child_layer] = saved
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
        try:
            child_layer._activity_regularizer = serialized_fns.get(
                'activity_regularizer_fn')
        except AttributeError:
            # Some layers have an unsettable activity regularizer.
            pass
        wrapped_call = utils.use_wrapped_call(
            child_layer,
            serialized_fns['call_and_return_conditional_losses'],
            default_training_value=False)
        child_layer.call = wrapped_call
def call(self, inputs, training=None, sample_shape=(), projection=None, **kwargs):
    """Build the posterior distribution for `inputs`.

    Parameters
    ----------
    inputs : tensor
        Input features; reshaped to `(batch, -1)` when `self.flatten_inputs`
        is set.
    training : bool, optional
        Forwarded to dropout, and to the posterior layer when `'training'` is
        listed in `self._posterior_call_kw`.
    sample_shape : tuple, optional
        Stored as `self._posterior_sample_shape` and forwarded to the posterior
        layer when `'sample_shape'` is listed in `self._posterior_call_kw`.
    projection : bool, optional
        Overrides `self.projection` when not None; when truthy, the inputs go
        through `self._dense`.
    **kwargs
        Only the keys listed in `self._posterior_call_kw` are forwarded to the
        posterior layer.

    Returns
    -------
    The object returned by `self.posterior_layer`, with a `KL_divergence`
    attribute attached.
    """
    # NOTE: a 2D input is important here, but we don't want to flatten
    # automatically
    if self.flatten_inputs:
        inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
    params = inputs
    # do not use tf.cond here, it infers the wrong shape when
    # trying to build the layer in Graph mode.
    projection = projection if projection is not None else self.projection
    if projection:
        params = self._dense(params)
        # NOTE(review): the original indentation was lost; this branch is
        # assumed to apply only to the projected output - confirm.
        if self.autoregressive:
            params = tf.concat(tf.unstack(params, axis=-1), axis=-1)
    # applying dropout
    if self._dropout > 0:
        params = bk.dropout(params, p_drop=self._dropout, training=training)
    # create posterior distribution
    self._posterior_sample_shape = sample_shape
    # only pass the keyword arguments the posterior layer declares
    kw = dict()
    if 'training' in self._posterior_call_kw:
        kw['training'] = training
    if 'sample_shape' in self._posterior_call_kw:
        kw['sample_shape'] = sample_shape
    for k, v in kwargs.items():
        if k in self._posterior_call_kw:
            kw[k] = v
    posterior = self.posterior_layer(params, **kw)
    # tensorflow tries to serialize the distribution, which raises an exception
    # when saving the graphs; to avoid this, store it as an untracked
    # attribute.
    with trackable.no_automatic_dependency_tracking_scope(self):
        # self._no_dependency
        self._most_recently_built_distribution = posterior
    # NOTE: all distributions have the method kl_divergence, so we cannot use
    # that name
    posterior.KL_divergence = KLdivergence(
        posterior, prior=self.prior,
        sample_shape=None)  # None means reuse sampled data here
    return posterior
def _init_from_metadata(cls, metadata):
    """Creates a revived network from metadata stored in the SavedModel proto.

    Args:
      cls: Class to instantiate; constructed with the saved `name` only.
      metadata: Dictionary of saved attributes. `expects_training_arg` is
        required; `config` and `activity_regularizer` are optional.

    Returns:
      Tuple of the revived network and `_revive_setter`.
    """
    network = cls(name=metadata['name'])

    # Attributes revived from SerializedAttributes are stored un-tracked. They
    # are the ones listed in CommonEndpoints, or "keras_api" for
    # keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(network):
        # pylint:disable=protected-access
        network._expects_training_arg = metadata['expects_training_arg']

        saved_config = metadata.get('config')
        if generic_utils.validate_config(saved_config):
            network._config = saved_config

        saved_regularizer = metadata.get('activity_regularizer')
        if saved_regularizer is not None:
            network.activity_regularizer = regularizers.deserialize(
                saved_regularizer)
        # pylint:enable=protected-access

    return network, _revive_setter  # pylint:disable=protected-access
def __new__(cls, *args, **kwargs):
    """Create the instance and record its effective constructor arguments.

    Defaults declared by the `__init__` of every `keras.Model` class in the
    MRO are collected from base to derived, then overridden by the positional
    and keyword arguments of this call. The result is stored on the instance
    as the untracked attribute `_init_args`.
    """
    model_classes = [c for c in type.mro(cls) if issubclass(c, keras.Model)]
    # Walk from the most basic class to `cls` so that defaults of derived
    # classes win over those of their parents.
    init_args = {}
    for klass in reversed(model_classes):
        spec = inspect.getfullargspec(klass.__init__)
        if spec.defaults is not None:
            # Defaults align with the trailing positional parameters.
            init_args.update(zip(spec.args[::-1], spec.defaults[::-1]))
    # Positional arguments are matched against the last inspected signature,
    # skipping `self`.
    init_args.update(zip(spec.args[1:], args))
    init_args.update(kwargs)
    # Copy so that the stored dictionary is not the one modified by init.
    init_args = copy.copy(init_args)
    instance = super().__new__(cls, *args, **kwargs)
    # `_init_args` must be a non-dependency (i.e. not trackable, and not saved
    # by save_weights).
    with trackable.no_automatic_dependency_tracking_scope(instance):
        instance._init_args = init_args
    return instance
def _add_serialized_attributes(layer, metadata):
    """Attaches `metadata` to `layer` as untracked `_serialized_attributes`.

    Attributes revived from SerializedAttributes are kept in this un-tracked
    dictionary. They are the ones listed in CommonEndpoints, or "keras_api"
    for keras-specific attributes.

    Args:
      layer: Object receiving the `_serialized_attributes` dictionary.
      metadata: Value stored under the `'metadata'` key.
    """
    # pylint: disable=protected-access
    serialized_attributes = {'metadata': metadata}
    with trackable.no_automatic_dependency_tracking_scope(layer):
        layer._serialized_attributes = serialized_attributes
def wrapper(self, *args, **kwargs):
    """Records the call arguments on `self`, then delegates to `func`."""
    # NOTE(review): `_init_args` is assigned outside the no-tracking scope
    # while `_init_kwargs` is assigned inside it - confirm this asymmetry is
    # intended.
    self._init_args = args
    with base.no_automatic_dependency_tracking_scope(self):
        self._init_kwargs = kwargs
    result = func(self, *args, **kwargs)
    return result
def build(self, input_shape=None):
    """
    Create any Variables used in the model.

    Creates one weight per entry of ``self.base_arrays_init["trainable"]`` and
    ``["non_trainable"]``, one state variable per entry of
    ``self.base_arrays_init["state"]`` that has a meaningful initial value,
    and (re)builds any Keras layers used by TensorNode operators in
    ``self.plan``.

    Parameters
    ----------
    input_shape : list of tuple of int
        Shapes of all the inputs to this layer.
    """
    super().build(input_shape)

    tf.random.set_seed(self.seed)

    def get_initializer(init_vals):
        """Use more efficient initializers if possible to save memory.

        Returns ``(initializer, shape, dtype)``; ``initializer`` is None when
        none of the values carry an initial value.
        """
        values, shapes, dtype, minibatched = init_vals

        # initial value of None means that the initial value isn't used, so we
        # can use anything for the initial value
        if all(v is None for v in values):
            initializer = None
        elif all(v is None or np.all(v == 0) for v in values):
            initializer = tf.initializers.zeros()
        elif all(v is None or np.all(v == 1) for v in values):
            initializer = tf.initializers.ones()
        else:
            val = tf.concat(
                [
                    tf.zeros(s, dtype) if v is None else tf.cast(
                        tf.broadcast_to(v, s), dtype)
                    for v, s in zip(values, shapes)
                ],
                axis=1 if minibatched else 0,
            )
            initializer = lambda shape=None, dtype=None: val

        # figure out shape of full concatenated initial value
        # (`minibatched` doubles as the index of the concatenation axis)
        shape = list(shapes[0])
        shape[minibatched] = sum(x[minibatched] for x in shapes)

        return initializer, tuple(shape), dtype

    # variables for model parameters
    # (the dict itself is assigned without automatic dependency tracking)
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.base_params = OrderedDict()
    assert len(self.base_params) == 0
    for sig_type in ("trainable", "non_trainable"):
        for k, v in self.base_arrays_init[sig_type].items():
            initializer, shape, dtype = get_initializer(v)
            assert initializer is not None  # params should never be set
            self.base_params[k] = self.add_weight(
                initializer=initializer,
                shape=shape,
                dtype=dtype,
                trainable=sig_type == "trainable",
                name="base_params/%s_%s_%s" % (sig_type, dtype, "_".join(str(x) for x in shape)),
            )

    logger.debug("created base param variables")
    logger.debug([str(x) for x in self.base_params.values()])

    # variables to save the internal state of simulation between runs
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.saved_state = OrderedDict()
    for k, v in self.base_arrays_init["state"].items():
        initializer, shape, dtype = get_initializer(v)
        if initializer is not None:
            # don't need to save the state for signals where the initial value
            # doesn't matter
            self.saved_state[k] = tf.Variable(
                initial_value=lambda: initializer(shape=shape, dtype=dtype
                                                  ),
                shape=shape,
                dtype=dtype,
                trainable=False,
                name="saved_state/%s_%s" % (dtype, "_".join(str(x) for x in shape)),
            )

    logger.debug("created saved state variables")
    logger.debug([str(x) for x in self.saved_state.values()])

    # call build on any TensorNode Layers

    def unbuild(layer):
        """Mark ``layer`` and its sublayers as not built, clearing build losses."""
        assert layer.built

        # clear any losses attached to layer (they will be recreated in the
        # build step, so we don't want to keep around any losses
        # associated with the previous build)
        # note: not clearing layer._losses, because those are manually added
        # by the user (not created during the build process)
        layer._eager_losses = []
        layer._callable_losses = []

        layer.built = False

        for sub in layer._layers:
            if isinstance(sub, tf.keras.layers.Layer):
                unbuild(sub)

    # all operators in TensorNode groups whose function is a Keras layer
    layer_ops = [
        op for ops in self.plan
        if isinstance(ops[0], tensor_node.SimTensorNode) for op in ops
        if isinstance(op.func, tf.keras.layers.Layer)
    ]
    weight_gets = []
    weight_sets = []
    for op in layer_ops:
        if op.func in self._layers:
            # already built this layer
            continue

        # assemble the input shape(s) passed to the layer's build: an empty
        # shape when `op.time` is set, plus the minibatched input shape when
        # the op has an input
        if op.time is None:
            shape_in = []
        else:
            shape_in = [()]
        if op.input is not None:
            shape_in += [(self.minibatch_size, ) + op.shape_in]
        if len(shape_in) == 1:
            shape_in = shape_in[0]

        if op.func.built:
            # we rebuild the layer (even if it is already built),
            # because we need to build the weights within the TensorGraph
            # context

            # save the weight values so they can be restored
            # exactly inside the tensornode
            weights = op.func.weights
            weight_gets.extend(weights)

            # clear the results of previous build
            unbuild(op.func)
        else:
            weights = None

        with tf.name_scope(op.func.name):
            op.func.build(shape_in)

        if weights is not None:
            weight_sets.extend(op.func.weights)

        # add op func to _layers so that any weights are collected
        self._layers.append(op.func)

    if len(weight_gets) > 0:
        # do all the weight getting/setting in one go, for efficiency reasons

        ctx = (weight_gets[0].graph.as_default() if hasattr(
            weight_gets[0], "graph") else context.eager_mode())
        with ctx:
            weight_vals = tf.keras.backend.batch_get_value(weight_gets)
            tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals))

    # initialize state variables (need to do this manually because we're not
    # adding them to self.weights)
    # note: don't need to do this in eager mode, since variables are
    # initialized on creation
    # TODO: why does this cause problems if it is done before the tensornode
    # weight get/sets above?
    if not context.executing_eagerly():
        tf.keras.backend.batch_get_value(
            [var.initializer for var in self.saved_state.values()])
def call(self, inputs, training=None, progress=None, stateful=False):
    """
    Constructs the graph elements to simulate the model.

    Parameters
    ----------
    inputs : list of ``tf.Tensor``
        Input layers/tensors for the network (must match the structure defined
        in `.build_inputs`). The last entry provides the number of steps to
        run.
    training : bool
        Whether the network is being run in training or inference mode. If
        None, uses the symbolic Keras learning phase variable.
    progress : `.utils.ProgressBar`
        Progress bar for construction stage.
    stateful : bool
        Whether or not to build the model to support preserving the internal
        state between executions.

    Returns
    -------
    probe_arrays : list of ``tf.Tensor``
        Tensors representing the output of all the Probes in the network
        (order corresponding to ``self.model.probes``, which is the order the
        Probes were instantiated), followed by the number of steps run.

    Raises
    ------
    BuildError
        If ``training == 1`` while the graph was created with
        ``inference_only=True``.
    """
    super().call(inputs, training=training)

    if training == 1 and self.inference_only:
        raise BuildError(
            "TensorGraph was created with inference_only=True; cannot be built "
            "with training=%s" % training)

    tf.random.set_seed(self.seed)

    if progress is None:
        progress = utils.NullProgressBar()

    # reset signaldict
    self.signals.reset()

    # create these constants once here for reuse in different operators
    self.signals.dt = tf.constant(self.dt, self.dtype)
    self.signals.dt_val = self.dt  # store the actual value as well
    self.signals.zero = tf.constant(0, self.dtype)
    self.signals.one = tf.constant(1, self.dtype)

    # set up invariant inputs
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.node_inputs = {}
    for n, inp in zip(self.invariant_inputs, inputs):
        # specify shape of inputs (keras sometimes loses this shape information)
        inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])

        self.node_inputs[n] = inp

    # the step count is read from the first element of the last input
    self.steps_to_run = inputs[-1][0, 0]

    # initialize op builder
    build_config = builder.BuildConfig(
        inference_only=self.inference_only,
        lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
        cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
        rng=np.random.RandomState(self.seed),
        training=(tf.keras.backend.learning_phase()
                  if training is None else training),
    )
    self.op_builder = builder.Builder(self.plan, self.signals, build_config)

    # pre-build stage
    with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
        self.op_builder.build_pre(sub)

    # build stage
    with progress.sub("build stage",
                      max_value=len(self.plan) * self.unroll) as sub:
        steps_run, probe_arrays, final_internal_state, final_base_params = (
            self._build_loop(sub) if self.use_loop else self._build_no_loop(sub))

    # store these so that they can be accessed after the initial build
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.steps_run = steps_run
        self.probe_arrays = probe_arrays
        self.final_internal_state = final_internal_state
        self.final_base_params = final_base_params

    # logging
    logger.info("Number of reads: %d",
                sum(x for x in self.signals.read_types.values()))
    for x in self.signals.read_types.items():
        logger.info(" %s: %d", *x)
    logger.info("Number of writes: %d",
                sum(x for x in self.signals.write_types.values()))
    for x in self.signals.write_types.items():
        logger.info(" %s: %d", *x)

    # note: always return steps_run so that the simulation will run for the given
    # number of steps, even if there are no output probes
    outputs = list(probe_arrays.values()) + [steps_run]

    updates = []
    if stateful:
        # update saved state
        updates.extend(
            var.assign(val)
            for var, val in zip(self.saved_state.values(), final_internal_state))

    # if any of the base params have changed (due to online learning rules) then we
    # also need to assign those back to the original variable (so that their
    # values will persist). any parameters targeted by online learning rules
    # will be minibatched, so we only need to update the minibatched params.
    # NOTE(review): the original indentation was lost; this loop is assumed to
    # run regardless of `stateful` - confirm.
    for (key, var), val in zip(self.base_params.items(), final_base_params):
        # the minibatched flag is the last element of the init entry, which is
        # stored under either the non-trainable or the trainable group
        try:
            minibatched = self.base_arrays_init["non_trainable"][key][-1]
        except KeyError:
            minibatched = self.base_arrays_init["trainable"][key][-1]

        if minibatched:
            updates.append(var.assign(val))

    logger.info("Number of variable updates: %d", len(updates))

    if len(updates) > 0:
        # route the outputs through identity ops that depend on the updates
        with tf.control_dependencies(updates):
            outputs = [tf.identity(x) for x in outputs]

    return outputs
def fit(self,
        train: Union[TensorTypes, DatasetV2],
        valid: Optional[Union[TensorTypes, DatasetV2]] = None,
        valid_freq: int = 500,
        valid_interval: float = 0,
        optimizer: Union[str, List[str], OptimizerV2,
                         List[OptimizerV2]] = 'adam',
        learning_rate: float = 1e-3,
        clipnorm: Optional[float] = None,
        epochs: int = -1,
        max_iter: int = 1000,
        batch_size: int = 32,
        callback: Union[Callback, List[Callback]] = lambda: None,
        compile_graph: bool = True,
        autograph: bool = False,
        logging_interval: float = 3,
        skip_fitted: Union[bool, int] = False,
        terminate_on_nan: bool = True,
        logdir: Optional[str] = None,
        allow_none_gradients: bool = False,
        track_gradients: bool = False) -> Networks:
    """Simplified replacement for the keras `fit`, built on
    `Networks.optimize` and the stored `Trainer`.

    Parameters
    ----------
    train : Union[TensorTypes, DatasetV2]
        training data, converted with `_to_dataset`
    valid : Optional[Union[TensorTypes, DatasetV2]], optional
        validation data, converted with `_to_dataset`, by default None
    valid_freq : int, optional
        frequency, in steps, of the validation, by default 500
    valid_interval : float, optional
        interval, in seconds, of the validation, by default 0
    optimizer : Union[str, OptimizerV2], optional
        used only when no optimizer is stored on the network yet; a list is
        accepted for multiple-steps training. If `None`, the stored optimizer
        is re-used, by default 'adam'
    learning_rate : float, optional
        learning rate for initializing the optimizer, by default 1e-3
    clipnorm : Optional[float], optional
        global L2-norm value for clipping the gradients, by default None
    epochs : int, optional
        passed to `train.repeat` when the dataset supports it, by default -1
    max_iter : int, optional
        maximum number of iterations, by default 1000
    batch_size : int, optional
        number of examples per mini-batch, by default 32
    callback : Union[Callback, List[Callback]], optional
        function(s) called every `valid_freq` steps or `valid_interval`
        seconds, by default `lambda: None`
    compile_graph : bool, optional
        compile the optimize function into a graph (faster); otherwise run in
        Eager mode (better for debugging), by default True
    autograph : bool, optional
        use autograph to compile the function, by default False
    logging_interval : float, optional
        interval, in seconds, for printing logging information, by default 3
    skip_fitted : Union[bool, int], optional
        `True` skips training when the model is already fitted; an integer
        skips it once the model has been fitted for at least that many steps,
        by default False
    terminate_on_nan : bool, optional
        terminate the training if NaNs are returned, by default True
    logdir : Optional[str], optional
        tensorboard logging directory, only used when the trainer is created
        by this call, by default None
    allow_none_gradients : bool, optional
        allow variables with None gradients during training, by default False
    track_gradients : bool, optional
        track and return metrics including the gradients' L2-norm of each
        trainable variable, by default False

    Returns
    -------
    Networks
        the network itself, for method chaining

    Raises
    ------
    RuntimeError
        if the network is not built, or if no optimizer is defined.
    """
    if not self.built:
        raise RuntimeError(
            "build(input_shape) method must be called to initialize "
            "the variables before calling fit")
    batch_size = int(batch_size)
    # normalize both inputs into datasets
    train = _to_dataset(train, batch_size, self.dtype)
    if valid is not None:
        valid = _to_dataset(valid, batch_size, self.dtype)
    # nothing to do when the model is already fitted (optionally: fitted for
    # at least `skip_fitted` steps)
    if self.is_fitted and skip_fitted:
        if (isinstance(skip_fitted, bool) or
                int(self.step.numpy()) >= int(skip_fitted)):
            return self
    # lazily create the trainer, without tracking it as a dependency
    if self.trainer is None:
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.trainer = Trainer(logdir=logdir)
    # if already called repeat, then no need to repeat more
    if hasattr(train, 'repeat'):
        train = train.repeat(int(epochs))
    # create the optimizer with tracking turned off, so that it won't be
    # saved in save_weights
    if optimizer is not None and self.optimizer is None:
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.optimizer = _to_optimizer(optimizer, learning_rate, clipnorm)
    if self.optimizer is None:
        raise RuntimeError("No optimizer found!")
    # run the training loop (early stopping and callbacks are handled there)
    optimize_fn = partial(self.optimize,
                          allow_none_gradients=allow_none_gradients,
                          track_gradients=track_gradients)
    self.trainer.fit(
        train_ds=train,
        optimize=optimize_fn,
        valid_ds=valid,
        valid_freq=valid_freq,
        valid_interval=valid_interval,
        compile_graph=compile_graph,
        autograph=autograph,
        logging_interval=logging_interval,
        log_tag=self.name,
        max_iter=max_iter,
        terminate_on_nan=terminate_on_nan,
        callback=callback,
    )
    return self
def fit(
    self,
    train: Union[TensorTypes, DatasetV2],
    valid: Optional[Union[TensorTypes, DatasetV2]] = None,
    valid_freq: int = 500,
    valid_interval: float = 0,
    optimizer: Union[str, OptimizerV2] = 'adam',
    learning_rate: float = 1e-3,
    clipnorm: Optional[float] = None,
    epochs: int = -1,
    max_iter: int = 1000,
    batch_size: int = 32,
    sample_shape: List[int] = (),  # for ELBO
    analytic: Optional[bool] = None,  # for ELBO
    iw: bool = False,  # for ELBO
    callback: Callable[[], Optional[dict]] = lambda: None,
    compile_graph: bool = True,
    autograph: bool = False,
    logging_interval: float = 3,
    skip_fitted: bool = False,
    terminate_on_nan: bool = True,
    logdir: Optional[str] = None,
    allow_none_gradients: bool = False,
    track_gradients: bool = False):
    r""" Simplified replacement for the keras `fit`, built on
    `Networks.optimize` and the stored `Trainer`.

    Arguments:
      train, valid : training and (optional) validation data, converted with
        `_to_dataset` using `batch_size`.
      optimizer : Text, instance of `tf.optimizers.Optimizer` or `None`.
        A list of optimizers is accepted in case of multiple steps training.
        Only used when no optimizer is stored on the network yet.
        - If `None`, re-use stored optimizer, raise `RuntimeError` if no
          predefined optimizer found.
      epochs : passed to `train.repeat` when the dataset supports it.
      callback : a Callable, called every `valid_freq` steps or
        `valid_interval` seconds
      compile_graph : a Boolean. If True, using tensorflow autograph for
        optimize function (about 2 times better speed), otherwise, run the
        function in Eager mode (better for debugging).
      skip_fitted : `True` skips training when the model is already fitted; an
        integer skips it once the model was fitted for at least that many
        steps.
      logdir : logging directory, only used when the trainer is created by
        this call.
      sample_shape, analytic, iw : not read by this method.

    Returns:
      the network itself, for method chaining.
    """
    batch_size = int(batch_size)
    # normalize both inputs into datasets
    train = _to_dataset(train, batch_size, self.dtype)
    if valid is not None:
        valid = _to_dataset(valid, batch_size, self.dtype)
    # nothing to do when the model is already fitted (optionally: fitted for
    # at least `skip_fitted` steps)
    if self.is_fitted and skip_fitted:
        if (isinstance(skip_fitted, bool) or
                int(self.step.numpy()) >= int(skip_fitted)):
            return self
    # lazily create the trainer, without tracking it as a dependency
    if self.trainer is None:
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.trainer = Trainer(logdir=logdir)
    # if already called repeat, then no need to repeat more
    if hasattr(train, 'repeat'):
        train = train.repeat(int(epochs))
    # create the optimizer with tracking turned off, so that it won't be
    # saved in save_weights
    if optimizer is not None and self.optimizer is None:
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.optimizer = _to_optimizer(optimizer, learning_rate, clipnorm)
    if self.optimizer is None:
        raise RuntimeError("No optimizer found!")
    # run the training loop (early stopping and callbacks are handled there)
    optimize_fn = partial(self.optimize,
                          allow_none_gradients=allow_none_gradients,
                          track_gradients=track_gradients)
    self.trainer.fit(
        train_ds=train,
        optimize=optimize_fn,
        valid_ds=valid,
        valid_freq=valid_freq,
        valid_interval=valid_interval,
        compile_graph=compile_graph,
        autograph=autograph,
        logging_interval=logging_interval,
        log_tag=self.name,
        max_iter=max_iter,
        terminate_on_nan=terminate_on_nan,
        callback=callback,
    )
    return self
def _restore_layer_losses(losses_dict):
    """Reassigns the saved `_losses` and `_eager_losses` to every layer.

    Args:
      losses_dict: Dictionary mapping each layer to a dictionary with its
        saved `'losses'` and `'eager_losses'`.
    """
    for layer, saved in losses_dict.items():
        with trackable.no_automatic_dependency_tracking_scope(layer):
            # pylint: disable=protected-access
            layer._losses = saved['losses']
            layer._eager_losses = saved['eager_losses']
def _finalize(self):
    """Finishes setting up every revived node in `self._nodes`.

    First pass: marks revived layers as built and installs their wrapped call
    function. Second pass: sets inputs/outputs of revived networks, re-adds
    unconditional losses and the activity regularizer of revived layers, and
    removes the tracking of every name in `PUBLIC_ATTRIBUTES` from each node.
    """
    # pylint: disable=protected-access
    # Set up call functions for all layers (skip this step for Sequential and
    # Functional models).
    for node in self._nodes:
        if isinstance(node, RevivedLayer):
            node.built = True
            is_graph_network = node._serialized_attributes['metadata'].get(
                'is_graph_network', False)
            if not (isinstance(node, models_lib.Sequential) or is_graph_network):
                if hasattr(node.keras_api, 'call_and_return_conditional_losses'):
                    node.call = utils.use_wrapped_call(
                        node, node.keras_api.call_and_return_conditional_losses,
                        return_method=True)
                    node._init_call_fn_args()

    for node in self._nodes:
        if isinstance(node, RevivedNetwork):
            # Take the inputs from the call function's signature when it has
            # one, otherwise infer them from the restored function.
            call_fn = node.keras_api.call_and_return_conditional_losses
            if call_fn.input_signature is None:
                inputs = infer_inputs_from_restored_call_function(call_fn)
            else:
                inputs = call_fn.input_signature[0]

            # Set model inputs and outputs.
            is_graph_network = node._serialized_attributes['metadata'].get(
                'is_graph_network', False)
            if isinstance(node, models_lib.Sequential):
                # Empty the layer list without tracking it, then add the
                # loaded layers back one at a time.
                with trackable.no_automatic_dependency_tracking_scope(
                        node):
                    node._layers = []
                for layer in node.keras_api.layers:
                    node.add(layer)
            elif is_graph_network:
                # Reconstruct functional model from the config and layers loaded
                # from the SavedModel.
                inputs, outputs, _ = network_lib.reconstruct_from_config(
                    node.get_config(),
                    created_layers={
                        layer.name: layer for layer in node.layers
                    })
                node._init_graph_network(
                    inputs, outputs,
                    name=node._serialized_attributes['metadata']['name'])
                # Set the metadata attributes once more, since _init_graph_network
                # resets these attributes.
                _set_network_attributes_from_metadata(node)
            else:  # Model is subclassed.
                node._set_inputs(inputs)

        # Add unconditional losses.
        if isinstance(node, RevivedLayer):
            if hasattr(node.keras_api, 'layer_regularization_losses'):
                losses = getattr(node.keras_api, 'layer_regularization_losses', [])
            else:
                # Some earlier SavedModels may not have layer_regularization_losses
                # serialized separately. Fall back to using the regularization_losses
                # list if it does not exist.
                losses = node._serialized_attributes.get(
                    'regularization_losses', [])
            for loss in losses:
                node.add_loss(loss)

            # Use wrapped activity regularizer function if the layer's activity
            # regularizer wasn't created during initialization.
            if node.activity_regularizer is None:
                node.activity_regularizer = getattr(
                    node.keras_api, 'activity_regularizer_fn', None)

        # Now that the node object has been fully loaded and restored from the
        # checkpoint, the object no longer needs to track objects added from
        # SerializedAttributes. (Note that saving a training checkpoint still
        # functions correctly, because layers and variables are tracked
        # separately by the Layer object.)
        # TODO(kathywu): Instead of outright deleting these nodes (which would
        # make restoring from a different checkpoint tricky), mark them as extra
        # dependencies that are OK to overwrite.
        for name in PUBLIC_ATTRIBUTES:
            delete_tracking(node, name)