def get(base_dir, cfg, model, train_steps, **params):
    """Build the ``CallbackList`` used to drive a training run.

    ``TerminateOnNaN`` is always installed. Depending on ``cfg`` the list
    also receives:

    * a weights checkpoint (``cfg.ckpt``),
    * a best-only weights checkpoint (``cfg.best_ckpt``),
    * a TensorBoard writer (``cfg.tensorboard``),
    * the callback from ``optimizer.lr_callback(cfg)`` (``cfg.lrp``).

    All output paths live under ``<base_dir>/<cfg.tag>/``.

    Args:
        base_dir: Root output directory.
        cfg: Run configuration. Reads ``tag``, ``ckpt``, ``best_ckpt``,
            ``tensorboard``, ``lrp``, ``verbose`` and ``total_epochs``.
        model: Model the callbacks are bound to.
        train_steps: Steps per epoch, exposed to callbacks as ``steps``.
        **params: Extra callback params. They are merged over the defaults
            (``verbose``, ``epochs``, ``steps``), so a caller may override
            any of them.

    Returns:
        A ``callbacks_module.CallbackList``. It also contains a History
        callback, plus a progress bar when ``cfg.verbose != 0``.
    """
    run_dir = f"{base_dir}/{cfg.tag}"
    callbacks = [K.callbacks.TerminateOnNaN()]

    if cfg.ckpt:
        callbacks.append(
            K.callbacks.ModelCheckpoint(
                f"{run_dir}/ckpt/{cfg.ckpt}",
                save_weights_only=True,
                verbose=1))

    if cfg.best_ckpt:
        callbacks.append(
            K.callbacks.ModelCheckpoint(
                f"{run_dir}/ckpt/{cfg.best_ckpt}",
                save_best_only=True,
                save_weights_only=True,
                verbose=1))

    if cfg.tensorboard:
        callbacks.append(
            K.callbacks.TensorBoard(
                f"{run_dir}/{cfg.tensorboard}", write_graph=False))

    if cfg.lrp:
        # Imported lazily: only needed when a learning-rate callback is
        # requested.
        from . import optimizer
        callbacks.append(optimizer.lr_callback(cfg))

    # Hand the run's verbosity/epochs/steps to the callbacks so that
    # callbacks such as the progress bar can see them. Caller-supplied params
    # take precedence, which also avoids duplicate keyword arguments.
    callback_params = {
        "verbose": cfg.verbose,
        "epochs": cfg.total_epochs,
        "steps": train_steps,
    }
    callback_params.update(params)

    return callbacks_module.CallbackList(
        callbacks,
        add_history=True,
        add_progbar=cfg.verbose != 0,
        model=model,
        **callback_params)
def fit(
    self,
    x=None,
    y=None,
    validation_data=None,
    epochs=1,
    verbose=0,
    callbacks=None,
    **kwargs,
):
    """Fit the model, bypassing gradient training for genetic optimizers.

    Genetic optimizer (``self.is_genetic``):
        Raw callbacks are first wrapped in a ``CallbackList`` (History
        enabled, no progress bar) and ``on_train_begin`` is fired. Training
        then consists of ``perform_genetic_fit`` running for ``epochs``
        epochs. ``**kwargs`` is not used on this path.

    Any other optimizer:
        Every argument, ``**kwargs`` included, is forwarded unchanged to the
        parent class's ``fit``.

    Returns:
        Whatever ``perform_genetic_fit`` or the parent ``fit`` returns.
    """
    if not self.is_genetic:
        return super().fit(
            x=x,
            y=y,
            validation_data=validation_data,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            **kwargs,
        )

    # The genetic loop drives callbacks by hand, so make sure we hold a
    # `tf.keras` callback container rather than a bare list / None.
    callback_list = callbacks
    if not isinstance(callback_list, callbacks_module.CallbackList):
        callback_list = callbacks_module.CallbackList(
            callback_list,
            add_history=True,
            add_progbar=False,
            model=self,
            verbose=verbose,
            epochs=epochs,
        )
    callback_list.on_train_begin()

    return self.perform_genetic_fit(
        x=x,
        y=y,
        epochs=epochs,
        verbose=verbose,
        validation_data=validation_data,
        callbacks=callback_list,
    )
def accuracy_aware_fit(cls_instance, train_dataset, compression_ctrl,
                       nncf_config, callbacks, initial_epoch,
                       uncompressed_model_accuracy, steps_per_epoch=None,
                       batch_size=None, tensorboard_writer=None,
                       log_dir=None, validation_data=None,
                       validation_steps=None,
                       result_dict_to_val_metric_fn=None, **kwargs):
    """Run accuracy-aware compression training on a Keras model.

    Builds a single-epoch ``DataHandler`` over ``train_dataset`` and a
    ``CallbackList``. It then hands per-epoch training and validation closures
    to the loop returned by ``create_accuracy_aware_training_loop``.

    Args:
        cls_instance: The Keras model being trained.
        train_dataset: Training data, passed to ``DataHandler`` as ``x``.
        compression_ctrl: Compression controller given to the loop factory.
        nncf_config: Config given to the loop factory.
        callbacks: A ``CallbackList`` or callbacks to wrap in one.
        initial_epoch: Forwarded to ``DataHandler``.
        uncompressed_model_accuracy: Stored on the model as
            ``original_model_accuracy``.
        steps_per_epoch: Forwarded to ``DataHandler``.
        batch_size: Forwarded to ``DataHandler``.
        tensorboard_writer: Forwarded to the loop's ``run``.
        log_dir: Forwarded to the loop's ``run``.
        validation_data: Data for ``model.evaluate``; defaults to
            ``train_dataset`` when ``None``.
        validation_steps: ``steps`` argument for ``model.evaluate``.
        result_dict_to_val_metric_fn: Maps the ``evaluate`` result dict to
            the value returned by ``validate_fn``. Identity by default.
        **kwargs: Accepted for interface compatibility; unused.
    """
    if result_dict_to_val_metric_fn is None:
        result_dict_to_val_metric_fn = lambda metric: metric

    with cls_instance.distribute_strategy.scope(), \
            training_utils.RespectCompiledTrainableState(cls_instance):
        # pylint: disable=protected-access
        data_handler = data_adapter.DataHandler(
            x=train_dataset,
            y=None,
            sample_weight=None,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=1,
            shuffle=True,
            class_weight=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            model=cls_instance,
            steps_per_execution=cls_instance._steps_per_execution)

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                model=cls_instance,
                epochs=1,
                verbose=1,
                add_progbar=True,
                steps=data_handler.inferred_steps)

    def train_epoch_fn(compression_ctrl, model, epoch):
        """Train ``model`` for one epoch, firing epoch and batch callbacks."""
        model.reset_metrics()

        if model.train_function is None:
            model.train_function = model.make_train_function()

        # The handler was built with epochs=1, so take its only epoch.
        _, iterator = next(data_handler.enumerate_epochs())

        # Bound before the loop: if no step runs, the emptiness check below
        # raises the intended ValueError instead of UnboundLocalError.
        logs = None
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
            for step in data_handler.steps():
                with trace.Trace('train', epoch_num=epoch, step_num=step,
                                 batch_size=None, _r=1):
                    callbacks.on_train_batch_begin(step)
                    tmp_logs = model.train_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
                    logs = tmp_logs
                    end_step = step + data_handler.step_increment
                    callbacks.on_train_batch_end(end_step, logs)
                    if model.stop_training:
                        break

        if logs is None:
            raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)
        callbacks.on_epoch_end(epoch, epoch_logs)

    if validation_data is None:
        validation_data = train_dataset

    def validate_fn(model, epoch=None):
        """Evaluate ``model`` on ``validation_data`` and map the result dict."""
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
        val_logs = model.evaluate(x=val_x, y=val_y,
                                  sample_weight=val_sample_weight,
                                  batch_size=None, steps=validation_steps,
                                  callbacks=callbacks, return_dict=True)
        return result_dict_to_val_metric_fn(val_logs)

    callbacks.on_train_begin()
    cls_instance.original_model_accuracy = uncompressed_model_accuracy
    acc_aware_training_loop = create_accuracy_aware_training_loop(
        nncf_config, compression_ctrl)
    cls_instance = acc_aware_training_loop.run(
        cls_instance,
        train_epoch_fn=train_epoch_fn,
        validate_fn=validate_fn,
        tensorboard_writer=tensorboard_writer,
        log_dir=log_dir)
    callbacks.on_train_end()
def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1,
        callbacks=None, validation_split=0., validation_data=None,
        shuffle=True, class_weight=None, sample_weight=None,
        initial_epoch=0, steps_per_epoch=None, validation_steps=None,
        validation_batch_size=None, validation_freq=1, max_queue_size=10,
        workers=1, use_multiprocessing=False):
    """From tf.keras.Model, extended with multi-step updates.

    When ``self._update_cycle > 1``:

    * ``self._grad_accumulator`` is reset at the start of every epoch.
    * Each training step first calls ``self.accumulate_function``
      ``_update_cycle - 1`` times on the same iterator, then runs the train
      function once. This presumably accumulates gradients before applying
      them (TODO confirm).

    Arguments and return value otherwise follow ``tf.keras.Model.fit``.

    Raises:
        ValueError: if not a single training step ran (empty input).
    """
    training._keras_api_gauge.get_cell('fit').set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph('Model', 'fit')
    self._assert_compile_was_called()
    self._check_call_args('fit')
    training._disallow_inside_tf_function('fit')

    if validation_split:
        # Create the validation data using the training data. Only supported
        # for `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split))

    if validation_data:
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))

    with self.distribute_strategy.scope(), \
            training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = data_adapter.DataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self,
            steps_per_execution=self._steps_per_execution)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps)

        self.stop_training = False
        train_function = self.make_train_function()
        self._train_counter.assign(0)
        callbacks.on_train_begin()
        training_logs = None
        # Bound up front so that input yielding no steps produces a clear
        # ValueError below rather than an UnboundLocalError.
        logs = None
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (  # pylint: disable=protected-access
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        for epoch, iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            with data_handler.catch_stop_iteration():
                if self._update_cycle > 1:
                    self._grad_accumulator.reset()
                for step in data_handler.steps():
                    with trace.Trace(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                        callbacks.on_train_batch_begin(step)
                        if self._update_cycle > 1:
                            for _ in range(self._update_cycle - 1):
                                self.accumulate_function(iterator)
                        tmp_logs = train_function(iterator)
                        if data_handler.should_sync:
                            context.async_wait()
                        logs = tmp_logs  # No error, now safe to assign to logs.
                        end_step = step + data_handler.step_increment
                        callbacks.on_train_batch_end(end_step, logs)
                        # Honour stop requests raised by batch-level
                        # callbacks (e.g. TerminateOnNaN) immediately.
                        if self.stop_training:
                            break

            if logs is None:
                raise ValueError('Expect x to be a non-empty array or dataset.')
            epoch_logs = copy.copy(logs)

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create data_handler for evaluation and cache it.
                if getattr(self, '_eval_data_handler', None) is None:
                    self._eval_data_handler = data_adapter.DataHandler(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_epoch=validation_steps,
                        initial_epoch=0,
                        epochs=1,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        model=self,
                        steps_per_execution=self._steps_per_execution)
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True)
                val_logs = {'val_' + name: val for name, val in val_logs.items()}
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        # If eval data_handler exists, delete it after all epochs are done.
        if getattr(self, '_eval_data_handler', None) is not None:
            del self._eval_data_handler
        callbacks.on_train_end(logs=training_logs)
        return self.history
def build_callbacks(self, conf, callbacks_list):
    '''Assemble the Keras callbacks used for logging and history.

    Based on Keras Callbacks
    https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

    Always installed:

    * BaseLogger
    * ``self.history``
    * a CSVLogger writing ``callbacks-<timestamp>.log`` inside
      ``conf['paths']['csvlog_save_path']`` (the directory is created if
      missing).

    EarlyStopping is added when ``"earlystop"`` is in the callbacks list.
    Other possible callbacks to add in future: RemoteMonitor,
    LearningRateScheduler.

    Argument list:
      - conf: parsed conf.yaml. Entries read from its "callbacks" section:
          - list: names of additional callbacks; used only when
            ``callbacks_list`` is None.
          - mode: one of {auto, min, max}. The decision to overwrite the
            current save file is made based on either the maximization or
            the minimization of the monitored quantity. For val_acc this
            should be max, for val_loss it should be min, etc. In auto mode
            the direction is inferred from the name of the monitored
            quantity.
          - monitor: quantity used for early stopping; has to be one of the
            monitored metrics.
          - patience: number of epochs used to decide whether to apply early
            stopping or continue training.
        Also reads ``conf['paths']['csvlog_save_path']``.
      - callbacks_list: list of additional callback names, normally read
        from the callbacks.list configuration parameter by the driver
        script. If None, ``conf['callbacks']['list']`` is used.

    Returns: a ``cbks.CallbackList`` holding the configured callbacks.
    '''
    callbacks_conf = conf['callbacks']
    mode = callbacks_conf['mode']
    monitor = callbacks_conf['monitor']
    patience = callbacks_conf['patience']
    csvlog_save_path = conf['paths']['csvlog_save_path']

    # CSV callback is on by default. exist_ok avoids the check-then-create
    # race of testing os.path.exists first.
    os.makedirs(csvlog_save_path, exist_ok=True)

    # Honour the caller's list; fall back to the config value.
    if callbacks_list is None:
        callbacks_list = callbacks_conf['list']

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    callbacks = [
        cbks.BaseLogger(),
        self.history,
        # os.path.join places the log inside the directory created above,
        # whether or not the configured path ends with a separator.
        cbks.CSVLogger(
            os.path.join(csvlog_save_path,
                         "callbacks-{}.log".format(timestamp))),
    ]

    if "earlystop" in callbacks_list:
        callbacks.append(
            cbks.EarlyStopping(patience=patience, monitor=monitor, mode=mode))

    # TODO: "lr_scheduler" is accepted in the list but not implemented yet.

    return cbks.CallbackList(callbacks)
def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1,
        dynamic_switch=True, callbacks=None, validation_split=0.,
        validation_data=None, shuffle=True, class_weight=None,
        sample_weight=None, initial_epoch=0, steps_per_epoch=None,
        validation_steps=None, validation_batch_size=None,
        validation_freq=1, max_queue_size=10, workers=1,
        use_multiprocessing=False):
    """Variant of ``tf.keras.Model.fit`` for windowed data with switching.

    Differences from the stock loop:

    * Data is served by ``WindowedDataHandler``. Each epoch trains on
      ``tf.data.Dataset.zip(next(window_iterator))``.
    * The trainable weights are snapshotted before an epoch is trained.
      After each pass ``self.update_and_switch(epoch, dynamic_switch,
      verbose)`` is called. When it returns a falsy value, the weights are
      restored from the snapshot, metrics are reset and the epoch's data is
      trained on again.
    * When ``self.accumulate_gradients`` is set,
      ``self.accumulated_gradients`` are applied through the optimizer once
      per epoch, before validation.
    * ``validation_split`` splits without shuffling.
    * A progress bar is added when ``verbose & Verbosity.Progress`` is set.

    Remaining arguments follow ``tf.keras.Model.fit``.

    Returns:
        ``self.history``.

    Raises:
        ValueError: if not a single training step ran (empty input).
    """
    training._keras_api_gauge.get_cell('fit').set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph('Model', 'fit')
    self._assert_compile_was_called()
    self._check_call_args('fit')

    if validation_split:
        # Create the validation data using the training data. Only supported
        # for `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight),
                validation_split=validation_split,
                shuffle=False))

    with self.distribute_strategy.scope(), \
            training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = WindowedDataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=bool(verbose & Verbosity.Progress),
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps)

        self.stop_training = False
        train_function = self.make_train_function()
        callbacks.on_train_begin()
        # Bound up front so that input yielding no steps produces a clear
        # ValueError below rather than an UnboundLocalError.
        logs = None
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        for epoch, window_iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            dataset = tf.data.Dataset.zip(next(window_iterator))
            switched = True
            # Snapshot so that a rejected pass can be rolled back.
            weights = backend.batch_get_value(self.trainable_variables)
            while switched:
                self.initialize_epoch(epoch)
                iterator = iter(dataset)
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with traceme.TraceMe(
                                'TraceContext',
                                graph_type='train',
                                epoch_num=epoch,
                                step_num=step,
                                batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            tmp_logs = train_function(iterator)
                            # Catch OutOfRangeError for Datasets of unknown
                            # size. This blocks until the batch has finished
                            # executing.
                            # TODO(b/150292341): Allow multiple async steps
                            # here.
                            if not data_handler.inferred_steps:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            callbacks.on_train_batch_end(step, logs)

                switched = not self.update_and_switch(
                    epoch, dynamic_switch, verbose)
                # If a switch occurred, we need to restore the weights
                if switched:
                    backend.batch_set_value(
                        zip(self.trainable_variables, weights))
                    self.reset_metrics()

            if logs is None:
                raise ValueError('Expect x to be a non-empty array or dataset.')
            epoch_logs = copy.copy(logs)

            if self.accumulate_gradients:
                self.optimizer.apply_gradients(
                    zip(self.accumulated_gradients, self.trainable_variables))

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                val_x, val_y, val_sample_weight = (
                    data_adapter.unpack_x_y_sample_weight(validation_data))
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True)
                val_logs = {
                    'val_' + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.stop_training:
                break

        callbacks.on_train_end()
        return self.history
def train(self, train_data, val_data=None, **kwargs):
    """Train ``self.model`` on ``train_data`` for ``cfg.epochs`` epochs.

    ``kwargs`` are merged into ``self.cfg.train`` before anything else.
    Inputs that are not already ``Sequence`` objects are converted with
    ``self.train_sequence`` / ``self.test_sequence`` and cached on
    ``self.cache``.

    Without validation data, a ``val_``-prefixed checkpoint monitor is
    rewritten to its training counterpart and a ``UserWarning`` is emitted.

    Depending on the config, EarlyStopping and ModelCheckpoint callbacks are
    installed. When checkpointing is enabled, the saved weights/model are
    loaded back after training. ``self.remove_weights()`` runs in a
    ``finally`` block when ``ModelCheckpoint.remove_weights`` is set.

    Returns:
        The ``History`` callback that recorded the run.
    """
    cache = self.cache
    cfg = self.cfg.train
    cfg.merge_from_dict(kwargs)
    ckpt_cfg = cfg.ModelCheckpoint
    es_cfg = cfg.EarlyStopping
    pb_cfg = cfg.Progbar

    model = self.model
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
        )

    if not isinstance(train_data, Sequence):
        train_data = self.train_sequence(train_data)
    cache.train_data = train_data

    validation = val_data is not None
    if validation:
        if not isinstance(val_data, Sequence):
            val_data = self.test_sequence(val_data)
        cache.val_data = val_data
    elif ckpt_cfg.enabled and ckpt_cfg.monitor.startswith("val_"):
        ckpt_cfg.monitor = ckpt_cfg.monitor[4:]
        warnings.warn(
            f"The metric 'val_{ckpt_cfg.monitor}' is invalid without validation "
            f"and has been automatically replaced with '{ckpt_cfg.monitor}'.",
            UserWarning)

    callbacks = callbacks_module.CallbackList()

    history = History()
    callbacks.append(history)

    if es_cfg.enabled:
        assert es_cfg.monitor.startswith("val")
        # `patience` must be the configured epoch count, not the monitor name.
        es_callback = EarlyStopping(monitor=es_cfg.monitor,
                                    patience=es_cfg.patience,
                                    mode=es_cfg.mode,
                                    verbose=es_cfg.verbose)
        callbacks.append(es_callback)

    if ckpt_cfg.enabled:
        if not ckpt_cfg.path.endswith(gg.file_ext()):
            ckpt_cfg.path += gg.file_ext()
        makedirs_from_filepath(ckpt_cfg.path)
        # Accept the correctly spelled `verbose` key, falling back to the
        # misspelled `vervose` key, and finally to silent.
        ckpt_verbose = getattr(ckpt_cfg, "verbose",
                               getattr(ckpt_cfg, "vervose", 0))
        mc_callback = ModelCheckpoint(
            ckpt_cfg.path,
            monitor=ckpt_cfg.monitor,
            save_best_only=ckpt_cfg.save_best_only,
            save_weights_only=ckpt_cfg.save_weights_only,
            verbose=ckpt_verbose)
        callbacks.append(mc_callback)

    callbacks.set_model(model)
    model.stop_training = False

    # verbose 1-2: one progress bar over epochs.
    # verbose > 2: a per-epoch bar over batches (created inside the loop).
    verbose = cfg.verbose
    if verbose:
        if verbose <= 2:
            progbar = Progbar(target=cfg.epochs,
                              width=pb_cfg.width,
                              verbose=verbose)
        print("Training...")

    logs = gf.BunchDict()
    callbacks.on_train_begin()
    try:
        for epoch in range(cfg.epochs):
            if verbose > 2:
                progbar = Progbar(target=len(train_data),
                                  width=pb_cfg.width,
                                  verbose=verbose - 2)

            callbacks.on_epoch_begin(epoch)
            callbacks.on_train_batch_begin(0)
            train_logs = self.train_step(train_data)
            train_data.on_epoch_end()
            logs.update(train_logs)

            if validation:
                valid_logs = self.test_step(val_data)
                logs.update({("val_" + k): v for k, v in valid_logs.items()})
                val_data.on_epoch_end()

            callbacks.on_train_batch_end(len(train_data), logs)
            callbacks.on_epoch_end(epoch, logs)

            if verbose > 2:
                print(f"Epoch {epoch+1}/{cfg.epochs}")
                progbar.update(len(train_data), logs.items())
            elif verbose:
                progbar.update(epoch + 1, logs.items())

            if model.stop_training:
                print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                break

        callbacks.on_train_end()
        if ckpt_cfg.enabled:
            if ckpt_cfg.save_weights_only:
                model.load_weights(ckpt_cfg.path)
            else:
                self.model = model.load(ckpt_cfg.path)

    finally:
        # to avoid unexpected termination of the model
        if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
            self.remove_weights()

    return history
def train(self, idx_train, idx_val=None, epochs=200, early_stopping=None,
          verbose=0, save_best=True, weight_path=None, as_model=False,
          monitor='val_acc', early_stop_metric='val_loss', callbacks=None,
          **kwargs):
    """Train the model for the input `idx_train` of nodes or `sequence`.

    Note:
    ----------
    You must compile your model before training/testing/predicting. Use
    `model.build()`.

    Parameters:
    ----------
    idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
        The index of nodes (or sequence) that will be used during training.
    idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
        The index of nodes (or sequence) that will be used for validation.
        (default :obj: `None`, i.e., do not use validation during training)
    epochs: Positive integer
        The number of epochs of training. (default :obj: `200`)
    early_stopping: Positive integer or None
        The number of early stopping patience during training. (default
        :obj: `None`, i.e., do not use early stopping during training)
    verbose: int in {0, 1, 2}
        'verbose=0': not verbose;
        'verbose=1': tqdm verbose;
        'verbose=2': tensorflow progbar verbose;
        (default :obj: 0)
    save_best: bool
        Whether to save the best weights (accuracy or loss depending on
        `monitor`) of training or validation (depending on whether
        validation is used). (default :bool: `True`)
    weight_path: String or None
        The path of saved weights/model. (default :obj: `None`, i.e.,
        `self.weight_path`)
    as_model: bool
        Whether to save the whole model or weights only. If `True`,
        `self.custom_objects` must be specified if you are using customized
        `layer` or `loss` and so on.
    monitor: String
        One of (val_loss, val_acc, loss, acc); determines which metric is
        used for `save_best`. Without validation a `val_` prefix is
        dropped. (default :obj: `val_acc`)
    early_stop_metric: String
        One of (val_loss, val_acc, loss, acc); determines which metric is
        used for early stopping. (default :obj: `val_loss`)
    callbacks: tensorflow.keras.callbacks.
        (default :obj: `None`)
    kwargs: other keyword Parameters.
        `es_verbose` (EarlyStopping verbosity, default 1) is consumed; the
        remainder is passed to `raise_if_kwargs`, which presumably rejects
        unexpected keys.

    Callbacks always receive 0-based epoch indices, whatever `verbose` is.

    Return:
    ----------
    A `tf.keras.callbacks.History` object. Its `History.history` attribute
    is a record of training loss values and metrics values at successive
    epochs, as well as validation loss values and validation metrics values
    (if applicable).
    """
    if verbose not in {0, 1, 2}:
        raise ValueError(
            "'verbose=0': not verbose; 'verbose=1': tqdm verbose; "
            "'verbose=2': tensorflow probar verbose; "
            f"but got {verbose}")

    model = self.model
    # Check if model has been built
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    # TODO: add metric names in `model`
    metric_names = ['loss', 'acc']
    callback_metrics = metric_names
    model.stop_training = False

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
        self.idx_train = idx_train

    validation = idx_val is not None
    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
            self.idx_val = idx_val
        callback_metrics = copy.copy(metric_names)
        callback_metrics += ['val_' + n for n in metric_names]
    else:
        # Validation metrics do not exist; monitor the training counterpart.
        monitor = 'acc' if monitor[:3] == 'val' else monitor

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks)

    history = tf_History()
    callbacks.append(history)

    if verbose == 2:
        callbacks.append(ProgbarLogger(stateful_metrics=metric_names[1:]))

    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=kwargs.pop('es_verbose', 1))
        callbacks.append(es_callback)

    if save_best:
        if not weight_path:
            weight_path = self.weight_path

        makedirs_from_path(weight_path)

        if not weight_path.endswith('.h5'):
            weight_path = weight_path + '.h5'

        mc_callback = ModelCheckpoint(weight_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)
    callbacks.set_model(model)

    # TODO: to be improved
    callback_params = {
        'batch_size': None,
        'epochs': epochs,
        'steps': 1,
        'samples': 1,
        'verbose': verbose == 2,
        'do_validation': validation,
        'metrics': callback_metrics,
    }
    callbacks.set_params(callback_params)
    raise_if_kwargs(kwargs)

    callbacks.on_train_begin()

    # Epoch indices are 0-based for every verbosity level (the Keras
    # convention), so History/EarlyStopping/ModelCheckpoint see the same
    # numbering; tqdm only changes the display.
    epoch_range = range(epochs)
    pbar = tqdm(epoch_range) if verbose == 1 else epoch_range

    for epoch in pbar:
        callbacks.on_epoch_begin(epoch)
        callbacks.on_train_batch_begin(0)
        loss, accuracy = self.train_step(train_data)

        training_logs = {'loss': loss, 'acc': accuracy}
        if validation:
            val_loss, val_accuracy = self.test_step(val_data)
            training_logs.update({
                'val_loss': val_loss,
                'val_acc': val_accuracy
            })
            val_data.on_epoch_end()

        callbacks.on_train_batch_end(0, training_logs)
        callbacks.on_epoch_end(epoch, training_logs)

        if verbose == 1:
            msg = "<" + "".join(f"{key.title()} = {val:.4f} "
                                for key, val in training_logs.items()) + ">"
            pbar.set_description(msg)

        train_data.on_epoch_end()

        if verbose == 2:
            print()

        if model.stop_training:
            break

    callbacks.on_train_end()

    if save_best:
        self.load(weight_path, as_model=as_model)
        remove_tf_weights(weight_path)

    return history
def fit_loop(model,
             inputs,
             targets,
             sample_weights=None,
             class_weight=None,
             val_inputs=None,
             val_targets=None,
             val_sample_weights=None,
             batch_size=None,
             epochs=1,
             verbose=1,
             callbacks=None,
             shuffle=True,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit function for eager execution.

    Arguments:
        model: Model instance executed in Eager mode.
        inputs: List of input arrays.
        targets: List of target arrays.
        sample_weights: Optional list of sample weight arrays.
        class_weight: Optional class-weight array to weight the importance
            of samples in `inputs` based on the class they belong to, as
            conveyed by `targets`.
        val_inputs: Input data for validation.
        val_targets: Target data for validation.
        val_sample_weights: Sample weight data for validation.
        batch_size: Integer batch size or None if unknown.
        epochs: Number of times to iterate over the data.
        verbose: Verbosity mode, 0, 1 or 2. Any truthy value adds a
            step-counting progress bar.
        callbacks: List of callbacks to be called during training.
        shuffle: Whether to shuffle the data at the beginning of each epoch.
        callback_metrics: List of strings, the display names of the metrics
            passed to the callbacks. Replaced by the model's metric names
            (plus their `val_` variants when validating) if the model is
            compiled.
        initial_epoch: Epoch at which to start training (useful for resuming
            a previous training run).
        steps_per_epoch: Total number of steps (batches of samples) before
            declaring one epoch finished and starting the next epoch.
            Ignored with the default value of `None`.
        validation_steps: Number of steps to run validation for (only if
            doing validation from data tensors). Ignored with default value
            of `None`.

    Returns:
        `History` object (also stored as `model.history`).

    Raises:
        ValueError: In case of invalid argument values.
    """
    # Turn the training arrays into an EagerIterator; the helper also
    # returns the (possibly inferred) number of steps per epoch.
    inputs, steps_per_epoch = training_utils.convert_to_iterator(
        x=inputs,
        y=targets,
        sample_weights=sample_weights,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        shuffle=shuffle)

    # Required for eager execution
    with backend.learning_phase_scope(1):
        do_validation = bool(val_inputs)

        out_labels = None
        if model._is_compiled:
            out_labels = model.metrics_names
            callback_metrics = copy.copy(out_labels)
            if do_validation:
                callback_metrics = callback_metrics + [
                    'val_' + name for name in out_labels
                ]

        model.history = cbks.History()
        callback_objs = ([cbks.BaseLogger()] + (callbacks or []) +
                         [model.history])
        if verbose:
            callback_objs.append(cbks.ProgbarLogger('steps'))
        callbacks = cbks.CallbackList(callback_objs)

        # It's possible to callback a different model than `model`
        # (used by Sequential models).
        callback_model = model
        if hasattr(model, 'callback_model') and model.callback_model:
            callback_model = model.callback_model
        callbacks.set_model(callback_model)

        callback_params = {
            'batch_size': batch_size,
            'epochs': epochs,
            'steps': steps_per_epoch,
            'samples': None,  # sample count is not known in this loop
            'verbose': verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics or [],
        }
        if validation_steps:
            callback_params['validation_steps'] = validation_steps
        callbacks.set_params(callback_params)

        def _validation_payload():
            # Builds a fresh value for each callback.
            if not val_inputs:
                return []
            if isinstance(val_inputs, iterator_ops.EagerIterator):
                return val_inputs
            if val_sample_weights:
                return val_inputs + val_targets + val_sample_weights
            return val_inputs + val_targets

        # validation_data must be set before on_train_begin() is called
        # so that TensorboardCallback can validate its input.
        for cbk in callbacks:
            cbk.validation_data = _validation_payload()

        callbacks.on_train_begin()
        callback_model.stop_training = False

        for epoch in range(initial_epoch, epochs):
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            iterator_fit_loop(
                model,
                inputs,
                class_weight,
                steps_per_epoch=steps_per_epoch,
                callback_model=callback_model,
                out_labels=out_labels,
                epoch_logs=epoch_logs,
                val_inputs=val_inputs,
                val_targets=val_targets,
                val_sample_weights=val_sample_weights,
                epochs=epochs,
                verbose=verbose,
                callbacks=callbacks,
                callback_metrics=callback_metrics,
                validation_steps=validation_steps,
                do_validation=do_validation,
                batch_size=batch_size)
            callbacks.on_epoch_end(epoch, epoch_logs)
            if callback_model.stop_training:
                break

        callbacks.on_train_end()
        return model.history
def train(self, train_data, val_data=None, **kwargs):
    """Train ``self.model`` on ``train_data`` for ``cfg.epochs`` epochs.

    ``kwargs`` are merged into ``self.cfg.train``. Inputs that are not
    already ``Sequence`` objects are converted with ``self.train_sequence``
    / ``self.test_sequence``, and cached on ``self.cache`` when
    ``cfg.cache_train_data`` / ``cfg.cache_val_data`` are set.

    Callbacks beyond ``History`` are configured by ``setup_callbacks``.
    When checkpointing is enabled, the saved weights/model are loaded back
    after training. ``self.remove_weights()`` runs in a ``finally`` block
    when ``ModelCheckpoint.remove_weights`` is set.

    Returns:
        The ``History`` callback that recorded the run.
    """
    cache = self.cache
    cfg = self.cfg.train
    cfg.merge_from_dict(kwargs)

    model = self.model
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
        )

    if not isinstance(train_data, Sequence):
        train_data = self.train_sequence(train_data)

    if cfg.cache_train_data:
        cache.train_data = train_data

    validation = val_data is not None
    if validation:
        if not isinstance(val_data, Sequence):
            val_data = self.test_sequence(val_data)
        if cfg.cache_val_data:
            cache.val_data = val_data

    # Setup callbacks
    callbacks = callbacks_module.CallbackList()
    history = History()
    callbacks.append(history)
    cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
    # Read sub-configs from the cfg returned by setup_callbacks, so that any
    # adjustments it made (e.g. to the checkpoint settings) are the ones
    # used below.
    ckpt_cfg = cfg.ModelCheckpoint
    pb_cfg = cfg.Progbar

    callbacks.set_model(model)
    model.stop_training = False

    # verbose 1-2: one progress bar over epochs.
    # verbose > 2: a per-epoch bar over batches (created inside the loop).
    verbose = cfg.verbose
    if verbose:
        if verbose <= 2:
            progbar = Progbar(target=cfg.epochs,
                              width=pb_cfg.width,
                              verbose=verbose)
        print("Training...")

    logs = gf.BunchDict()
    callbacks.on_train_begin()
    try:
        for epoch in range(cfg.epochs):
            if verbose > 2:
                progbar = Progbar(target=len(train_data),
                                  width=pb_cfg.width,
                                  verbose=verbose - 2)

            callbacks.on_epoch_begin(epoch)
            callbacks.on_train_batch_begin(0)
            train_logs = self.train_step(train_data)
            train_data.on_epoch_end()
            logs.update(train_logs)

            if validation:
                valid_logs = self.test_step(val_data)
                logs.update({("val_" + k): v for k, v in valid_logs.items()})
                val_data.on_epoch_end()

            callbacks.on_train_batch_end(len(train_data), logs)
            callbacks.on_epoch_end(epoch, logs)

            if verbose > 2:
                # `epochs` is not a local name here; use the config value.
                print(f"Epoch {epoch+1}/{cfg.epochs}")
                progbar.update(len(train_data), logs.items())
            elif verbose:
                progbar.update(epoch + 1, logs.items())

            if model.stop_training:
                print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                break

        callbacks.on_train_end()
        if ckpt_cfg.enabled:
            if ckpt_cfg.save_weights_only:
                model.load_weights(ckpt_cfg.path)
            else:
                self.model = model.load(ckpt_cfg.path)

    finally:
        # to avoid unexpected termination of the model
        if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
            self.remove_weights()

    return history
def fit_loop(model,
             inputs,
             targets,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """fit function when using DistributionStrategy for training.

    Arguments:
        model: Keras Model instance.
        inputs: List of input arrays.
        targets: List of target arrays.
        epochs: Number of times to iterate over the data
        verbose: Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        val_inputs: List of input arrays.
        val_targets: List of target arrays.
        callback_metrics: List of strings, the display names of the metrics
            passed to the callbacks. Note: this argument is recomputed from
            `model.metrics_names` inside the function, so the passed value is
            not used.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Batches are only run when this is set.
        validation_steps: Number of steps to run validation for (only if doing
            validation from data tensors). Ignored with the default value of
            `None`.

    Returns:
        `History` object.

    Raises:
        ValueError: if `validation_steps` is given without `steps_per_epoch`.
    """
    current_strategy = model._distribution_strategy

    def _per_device_train_function(model):
        model._make_train_function()
        return (model.train_function.inputs,
                model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_tower`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops
        # on all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args, with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to
        # be unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        # Create a train function that is composed of all the parameters above.
        distributed_train_function = K.Function(
            all_inputs, all_outputs,
            updates=all_updates,
            name='distributed_train_function',
            **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [
            None for _ in range(len(model.outputs) * current_strategy.num_towers)
        ]
        if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
            ins = dataset_inputs + dataset_targets + sample_weights + [1]
        else:
            ins = dataset_inputs + dataset_targets

    do_validation = False
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')
    out_labels = model.metrics_names
    if do_validation:
        callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
        ]
    else:
        callback_metrics = copy.copy(out_labels)

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        # We assume that `steps_per_epoch` is always set since we have to use
        # Datasets.
        count_mode = 'steps'
        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # We set the callback model to an instance of the `DistributedModel` that
    # we create in the `compile` call. The `DistributedModel` is initialized
    # with the first replicated model. We need to set the callback model to a
    # DistributedModel to allow us to override saving and loading weights when
    # we checkpoint the model during training.
    callback_model = model._replicated_model
    callbacks.set_model(callback_model)

    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': None,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(
            current_strategy, distributed_model, orig_model_weights)

    # Number of per-device values emitted for each metric; loop invariant.
    num_devices = len(current_strategy._devices)

    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        # Always bound, so `on_epoch_end` below never sees an undefined name.
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {}
                batch_logs['batch'] = step_index
                batch_logs['size'] = 1
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    # Parenthesize the product: `%` and `*` share precedence and
                    # associate left-to-right, so without the parentheses the
                    # formatted message would be repeated `epochs` times.
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                # TODO(anjalisridhar): Temporary workaround for aggregating
                # metrics across towers. Replace with the new metrics module
                # eventually.
                merged_output = []
                # The first output is the total loss.
                merged_output.append(outs[0])
                current_index = 1
                # Each label in `out_labels` corresponds to one set of metrics.
                # The number of metric values corresponds to the number of
                # devices. We currently take the mean of the values.
                for _ in out_labels[1:]:
                    m = np.mean(outs[current_index:current_index + num_devices])
                    merged_output.append(m)
                    current_index += num_devices

                # Log the aggregated values: the raw `outs` hold one entry per
                # device per metric and do not line up with `out_labels`.
                for l, o in zip(out_labels, merged_output):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(
                    model,
                    val_inputs,
                    val_targets,
                    steps=validation_steps,
                    verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
    return model.history
def fit(self, train_data, val_data=None, **kwargs):
    """Fit ``self.model`` on ``train_data`` and return the ``History`` callback.

    Options come from ``self.cfg.fit`` after ``kwargs`` are merged into it.
    Progress is reported through a ``Progbar`` (``verbose`` > 0) or, when
    ``cfg.Logger.enabled`` is set, through a logger. The two are mutually
    exclusive.

    Parameters
    ----------
    train_data:
        A ``Sequence``/``DataLoader``/``Dataset``, or anything accepted by
        ``self.train_loader``.
    val_data:
        Optional validation data: a ``Sequence``/``DataLoader``/``Dataset``,
        or anything accepted by ``self.test_loader``.
    **kwargs:
        Overrides merged into ``self.cfg.fit``.

    Returns
    -------
    History
        The ``History`` callback registered for this run.

    Raises
    ------
    RuntimeError
        If the model has not been built.
    AssertionError
        If both ``verbose`` and the logger are enabled.
    """
    cache = self.cache
    cfg = self.cfg.fit
    cfg.merge_from_dict(kwargs)

    # Sub-configs are captured before `setup_callbacks` rebinds `cfg`.
    ckpt_cfg = cfg.ModelCheckpoint
    es_cfg = cfg.EarlyStopping  # read but currently unused
    pb_cfg = cfg.Progbar
    log_cfg = cfg.Logger

    if log_cfg.enabled:
        log_cfg.name = log_cfg.name or self.name
        logger = gg.utils.setup_logger(output=log_cfg.filepath,
                                       name=log_cfg.name)

    model = self.model
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
        )

    loader_types = (Sequence, DataLoader, Dataset)

    if not isinstance(train_data, loader_types):
        train_data = self.train_loader(train_data)
    if cfg.cache_train_data:
        cache.train_data = train_data

    validation = val_data is not None
    if validation:
        if not isinstance(val_data, loader_types):
            val_data = self.test_loader(val_data)
        if cfg.cache_val_data:
            cache.val_data = val_data

    # Callback wiring: History first, then whatever the config adds.
    history = History()
    callbacks = callbacks_module.CallbackList()
    callbacks.append(history)
    cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
    callbacks.set_model(model)
    model.stop_training = False

    verbose = cfg.verbose
    assert not (
        verbose and log_cfg.enabled
    ), "Progbar and Logger cannot be used together! You must set `verbose=0` when Logger is enabled."

    if not verbose:
        if log_cfg.enabled:
            logger.info("Training...")
    else:
        if verbose <= 2:
            bar = Progbar(target=cfg.epochs, width=pb_cfg.width, verbose=verbose)
        print("Training...")

    running_logs = gf.BunchDict()
    callbacks.on_train_begin()
    try:
        for epoch in range(cfg.epochs):
            per_epoch_bar = verbose > 2
            if per_epoch_bar:
                bar = Progbar(target=len(train_data),
                              width=pb_cfg.width,
                              verbose=verbose - 2)

            callbacks.on_epoch_begin(epoch)
            callbacks.on_train_batch_begin(0)

            step_logs = self.train_step(train_data)
            if hasattr(train_data, 'on_epoch_end'):
                train_data.on_epoch_end()
            running_logs.update(step_logs)

            if validation:
                eval_logs = self.test_step(val_data)
                running_logs.update(
                    {("val_" + name): value for name, value in eval_logs.items()})
                if hasattr(val_data, 'on_epoch_end'):
                    val_data.on_epoch_end()

            callbacks.on_train_batch_end(len(train_data), running_logs)
            callbacks.on_epoch_end(epoch, running_logs)

            if per_epoch_bar:
                print(f"Epoch {epoch+1}/{cfg.epochs}")
                bar.update(len(train_data), running_logs.items())
            elif verbose:
                bar.update(epoch + 1, running_logs.items())
            elif log_cfg.enabled:
                logger.info(
                    f"Epoch {epoch+1}/{cfg.epochs}\n{gg.utils.create_table(running_logs)}"
                )

            if not model.stop_training:
                continue
            if not log_cfg.enabled:
                print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
            else:
                logger.info(f"Early Stopping at Epoch {epoch}")
            break

        callbacks.on_train_end()

        if ckpt_cfg.enabled:
            if not ckpt_cfg.save_weights_only:
                self.model = model.load(ckpt_cfg.path)
            else:
                model.load_weights(ckpt_cfg.path)
    finally:
        # to avoid unexpected termination of the model
        if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
            self.remove_weights()
    return history
def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
    """See docstring for `Model.fit_generator`.

    Trains `model` batch-by-batch on the outputs of `generator`. Each output
    must be an `(x, y)` or `(x, y, sample_weight)` tuple, and each batch is
    passed to `model.train_on_batch`. `validation_data` may be a generator
    or `Sequence`, which is evaluated with `evaluate_generator`, or a 2-/3-tuple
    of arrays, which is evaluated with `model.evaluate`.

    Returns:
        `model.history`, the `History` callback created for this run.

    Raises:
        ValueError: if `steps_per_epoch` is None for a non-`Sequence`
            generator, if `validation_steps` is None for a non-`Sequence`
            validation generator, or if `validation_data` / a generator
            output is not a 2- or 3-tuple.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        logging.warning(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps_per_epoch is None:
        if is_sequence:
            # A Sequence knows its own length; one pass over it is one epoch.
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the `keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` or use'
                             ' the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (
        hasattr(validation_data, 'next') or
        hasattr(validation_data, '__next__') or
        isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence) and
            not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')

    # Prepare display labels.
    out_labels = model.metrics_names
    callback_metrics = out_labels + ['val_%s' % n for n in out_labels]

    # prepare callbacks
    model.history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
    if verbose:
        callbacks += [cbks.ProgbarLogger(count_mode='steps')]
    callbacks = cbks.CallbackList(callbacks)

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)

    callback_params = {
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    }
    if do_validation:
        # need to create the test_function before start of the first epoch
        # because TensorBoard callback on_epoch_begin adds summary to the
        # list of fetches of the test_function
        model._make_test_function()
        # determine the number of validation batches given a generator
        if validation_steps:
            callback_params.update({'validation_steps': validation_steps})
        elif isinstance(validation_data, Sequence):
            callback_params.update({'validation_steps': len(validation_data)})
    callbacks.set_params(callback_params)

    enqueuer = None
    # NOTE(review): never reassigned in this function, so the `val_enqueuer`
    # cleanup in the `finally` below is currently a no-op.
    val_enqueuer = None

    try:
        if do_validation and not val_gen:
            # Prepare data for validation
            if len(validation_data) == 2:
                val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
                val_sample_weight = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
            else:
                raise ValueError(
                    '`validation_data` should be a tuple '
                    '`(val_x, val_y, val_sample_weight)` '
                    'or `(val_x, val_y)`. Found: ' + str(validation_data))
            val_x, val_y, val_sample_weights = model._standardize_user_data(
                val_x, val_y, val_sample_weight)
            val_data = val_x + val_y + val_sample_weights
            if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
                # Learning-phase flag 0 (test mode) is appended to the feed.
                val_data += [0.]
            for cbk in callbacks:
                cbk.validation_data = val_data

        if workers > 0:
            # Produce batches in background workers through an enqueuer.
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            # workers == 0: pull batches directly in this thread.
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        callback_model.stop_training = False
        # validation_data must be set before on_train_begin() is called
        # so that TensorboardCallback can validate its input
        callbacks.on_train_begin()
        # Construct epoch logs.
        # (One dict shared by all epochs; its 'val_*' keys are overwritten
        # each time validation runs.)
        epoch_logs = {}
        while epoch < epochs:
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                # Batch size is the leading dimension of the (first) input.
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(
                    x, y, sample_weight=sample_weight, class_weight=class_weight)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = evaluate_generator(
                            model,
                            validation_data,
                            validation_steps,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            max_queue_size=max_queue_size)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        # `batch_size` here is the size of the last training batch.
                        val_outs = model.evaluate(
                            val_x,
                            val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        # Nested try/finally so the second cleanup still runs if the first raises.
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return model.history
def train_v2(self,
             idx_train,
             idx_val=None,
             epochs=200,
             early_stopping=None,
             verbose=False,
             save_best=True,
             weight_path=None,
             as_model=False,
             monitor='val_acc',
             early_stop_metric='val_loss',
             callbacks=None,
             **kwargs):
    """
    Train the model for the input `idx_train` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training.(default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details. (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss depending on `monitor`)
            of training or validation (depending on whether validation is used).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only, if `True`, the `self.custom_objects`
            must be specified if you are using customized `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters. Only `es_verbose` (verbosity of the
            early-stopping callback, default 0) is accepted; any other key
            raises `TypeError` before any callback or directory is created.

    Return:
    ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).

    """
    import re  # local: only needed for the version check below

    # Compare (major, minor) numerically; a plain string comparison would
    # wrongly treat e.g. '2.10.0' as older than '2.2.0'.
    tf_version = tuple(
        int(part) for part in re.findall(r'\d+', tf.__version__)[:2])
    if tf_version < (2, 2):
        raise RuntimeError(
            'This method is only work for tensorflow version >= 2.2.0.')

    # Check if model has been built
    if self.model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    # Validate keyword arguments before creating callbacks or directories.
    es_verbose = kwargs.pop('es_verbose', 0)
    # leave it blank for the future
    allowed_kwargs = set([])
    unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
    if unknown_kwargs:
        raise TypeError("Invalid keyword argument(s): %s" % (unknown_kwargs, ))

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
        self.idx_train = idx_train

    validation = idx_val is not None

    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
            self.idx_val = idx_val
    else:
        # Without validation data only training metrics exist.
        monitor = 'acc' if monitor[:3] == 'val' else monitor

    model = self.model

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks,
                                                  add_history=True,
                                                  add_progbar=True,
                                                  verbose=verbose,
                                                  epochs=epochs)
    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=es_verbose)
        callbacks.append(es_callback)

    if save_best:
        if not weight_path:
            weight_path = self.weight_path
        makedirs_from_path(weight_path)

        if not weight_path.endswith('.h5'):
            weight_path += '.h5'

        mc_callback = ModelCheckpoint(weight_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)
    callbacks.set_model(model)
    # A previous run may have left this set (e.g. after early stopping);
    # without the reset, training would stop after the first epoch.
    model.stop_training = False

    callbacks.on_train_begin()

    for epoch in range(epochs):
        callbacks.on_epoch_begin(epoch)

        # The whole epoch is reported to callbacks as a single batch.
        callbacks.on_train_batch_begin(0)
        loss, accuracy = self.train_step(train_data)
        train_data.on_epoch_end()

        training_logs = {'loss': loss, 'acc': accuracy}
        callbacks.on_train_batch_end(0, training_logs)

        if validation:
            val_loss, val_accuracy = self.test_step(val_data)
            training_logs.update({
                'val_loss': val_loss,
                'val_acc': val_accuracy
            })
            val_data.on_epoch_end()

        callbacks.on_epoch_end(epoch, training_logs)

        if model.stop_training:
            break

    callbacks.on_train_end()

    if save_best:
        # Restore the best checkpoint, then delete the temporary file.
        self.load(weight_path, as_model=as_model)
        remove_tf_weights(weight_path)

    return model.history
def train(self,
          idx_train,
          idx_val=None,
          epochs=200,
          early_stopping=None,
          verbose=0,
          save_best=True,
          weight_path=None,
          as_model=False,
          monitor='val_acc',
          early_stop_metric='val_loss',
          callbacks=None,
          **kwargs):
    """Fit the model on the nodes (or sequence) given by `idx_train`.

    The model must already be built (see `model.build()`).

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            Training node indices, or a ready-made sequence.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            Validation node indices or sequence. When `None`, no validation is
            run and a `val_*` `monitor` falls back to `acc`.
        epochs: Positive integer
            Number of training epochs. (default :obj: `200`)
        early_stopping: Positive integer or None
            Patience for early stopping; `None` disables it.
        verbose: int in {0, 1, 2, 3, 4}
            0: silent; 1/2: a single progress bar over epochs (detailed/omitted);
            3/4: a progress bar per epoch (detailed/omitted). (default :obj: 0)
        save_best: bool
            Checkpoint the best weights according to `monitor` and restore them
            once training finishes (even if it is interrupted). (default `True`)
        weight_path: String or None
            Checkpoint path; when `None`, `self.weight_path` is used. When given,
            it is also stored on `self.weight_path`.
        as_model: bool
            Save/restore the whole model instead of weights only; customized
            layers/losses then require `self.custom_objects`.
        monitor: String
            Metric used for `save_best` (val_loss, val_acc, loss, acc).
        early_stop_metric: String
            Metric used for early stopping (val_loss, val_acc, loss, acc).
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: must be empty; any keyword is rejected by `raise_if_kwargs`.

    Return:
    ----------
        The `History` callback of this run; `History.history` records the
        per-epoch training (and, if applicable, validation) loss and accuracy.
    """
    raise_if_kwargs(kwargs)
    if not isinstance(verbose, int) or not 0 <= verbose <= 4:
        raise ValueError("'verbose=0': not verbose"
                         "'verbose=1': Progbar(one line, detailed), "
                         "'verbose=2': Progbar(one line, omitted), "
                         "'verbose=3': Progbar(multi line, detailed), "
                         "'verbose=4': Progbar(multi line, omitted), "
                         f"but got {verbose}")

    model = self.model
    # Guard: nothing to train until the model is built.
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if not isinstance(idx_train, Sequence):
        idx_train = asintarr(idx_train)
        train_seq = self.train_sequence(idx_train)
        self.idx_train = idx_train
    else:
        train_seq = idx_train

    has_val = idx_val is not None
    if not has_val:
        if monitor[:3] == 'val':
            monitor = 'acc'
    elif not isinstance(idx_val, Sequence):
        idx_val = asintarr(idx_val)
        val_seq = self.test_sequence(idx_val)
        self.idx_val = idx_val
    else:
        val_seq = idx_val

    cb_list = callbacks
    if not isinstance(cb_list, callbacks_module.CallbackList):
        cb_list = callbacks_module.CallbackList(cb_list)

    history = History()
    cb_list.append(history)

    if early_stopping:
        cb_list.append(
            EarlyStopping(monitor=early_stop_metric,
                          patience=early_stopping,
                          mode='auto',
                          verbose=kwargs.pop('es_verbose', 1)))

    if save_best:
        if weight_path:
            self.weight_path = weight_path
        else:
            weight_path = self.weight_path
        makedirs_from_filename(weight_path)
        if not weight_path.endswith(POSTFIX):
            weight_path += POSTFIX
        cb_list.append(
            ModelCheckpoint(weight_path,
                            monitor=monitor,
                            save_best_only=True,
                            save_weights_only=not as_model,
                            verbose=0))

    cb_list.set_model(model)
    model.stop_training = False
    cb_list.on_train_begin()

    if verbose:
        stateful_metrics = {"acc", 'loss', 'val_acc', 'val_loss', 'time'}
        if verbose <= 2:
            bar = Progbar(target=epochs,
                          verbose=verbose,
                          stateful_metrics=stateful_metrics)
        print("Training...")

    t_start = time.perf_counter()
    try:
        for epoch in range(epochs):
            bar_per_epoch = verbose > 2
            if bar_per_epoch:
                bar = Progbar(target=len(train_seq),
                              verbose=verbose - 2,
                              stateful_metrics=stateful_metrics)

            cb_list.on_epoch_begin(epoch)
            cb_list.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_seq)
            epoch_logs = {'loss': loss, 'acc': accuracy}

            if has_val:
                val_loss, val_accuracy = self.test_step(val_seq)
                epoch_logs['val_loss'] = val_loss
                epoch_logs['val_acc'] = val_accuracy
                val_seq.on_epoch_end()

            cb_list.on_train_batch_end(len(train_seq), epoch_logs)
            cb_list.on_epoch_end(epoch, epoch_logs)
            train_seq.on_epoch_end()

            if verbose:
                # Elapsed wall time is added only for display, after callbacks ran.
                epoch_logs['time'] = time.perf_counter() - t_start
                if not bar_per_epoch:
                    bar.update(epoch + 1, epoch_logs.items())
                else:
                    print(f"Epoch {epoch+1}/{epochs}")
                    bar.update(len(train_seq), epoch_logs.items())

            if model.stop_training:
                break
    finally:
        cb_list.on_train_end()
        # Restore the best checkpoint even if training was interrupted.
        if save_best:
            self.load(weight_path, as_model=as_model)
            self.remove_weights()

    return history
def fit_loop(model,
             inputs,
             targets,
             sample_weights=None,
             batch_size=None,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             val_sample_weights=None,
             shuffle=True,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Abstract fit function for arrays of data.

    Arguments:
        model: Keras Model instance.
        inputs: List of input arrays.
        targets: List of target arrays.
        sample_weights: Optional list of sample weight arrays.
        batch_size: Integer batch size or None if unknown.
        epochs: Number of times to iterate over the data
        verbose: Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        val_inputs: List of input arrays.
        val_targets: List of target arrays.
        val_sample_weights: Optional list of sample weight arrays.
        shuffle: Whether to shuffle the data at the beginning of each epoch;
            the string 'batch' shuffles in batch-sized chunks.
        callback_metrics: List of strings, the display names of the metrics
            passed to the callbacks. Note: this argument is recomputed from
            `model.metrics_names` inside the function, so the passed value is
            not used.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When `None`, epochs iterate over the sample arrays in
            batches of `batch_size`.
        validation_steps: Number of steps to run validation for (only if doing
            validation from data tensors). Ignored with the default value of
            `None`.

    Returns:
        `History` object.

    Raises:
        ValueError: if `validation_steps` is given without `steps_per_epoch`.
        TypeError: if a batch cannot be sliced from the inputs (e.g. HDF5
            data without `shuffle="batch"`).
    """
    model._make_train_function()
    f = model.train_function

    sample_weights = sample_weights or []
    val_sample_weights = val_sample_weights or []
    # Append the learning-phase flag (1 = training) when the backend needs it.
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = inputs + targets + sample_weights + [1]
        if val_inputs:
            val_ins = val_inputs + val_targets + val_sample_weights + [1]
    else:
        ins = inputs + targets + sample_weights
        if val_inputs:
            val_ins = val_inputs + val_targets + val_sample_weights
    if not val_inputs:
        val_ins = []

    do_validation = False
    if val_inputs:
        do_validation = True
        if (steps_per_epoch is None and verbose and inputs and
                hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
            print('Train on %d samples, validate on %d samples' %
                  (inputs[0].shape[0], val_inputs[0].shape[0]))
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')

    out_labels = model.metrics_names
    if do_validation:
        callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
        ]
    else:
        callback_metrics = copy.copy(out_labels)

    num_train_samples = training_utils.check_num_samples(
        ins, batch_size, steps_per_epoch, 'steps_per_epoch')
    if num_train_samples is not None:
        index_array = np.arange(num_train_samples)

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        if steps_per_epoch is not None:
            count_mode = 'steps'
        else:
            count_mode = 'samples'
        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # it's possible to callback a different model than self
    # (used by Sequential models)
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model

    callbacks.set_model(callback_model)

    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': num_train_samples,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False
    for cbk in callbacks:
        cbk.validation_data = val_ins

    # To prevent a slowdown, we find beforehand the arrays that need conversion.
    feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    indices_for_conversion_to_dense = [
        i for i, feed_tensor in enumerate(feed)
        if issparse is not None and issparse(ins[i]) and
        not K.is_sparse(feed_tensor)
    ]

    for epoch in range(initial_epoch, epochs):
        # Reset stateful metrics
        for m in model.stateful_metric_functions:
            m.reset_states()
        # Update callbacks
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {}
                batch_logs['batch'] = step_index
                batch_logs['size'] = 1
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = f(ins)
                except errors.OutOfRangeError:
                    # Parenthesize the product: `%` and `*` share precedence and
                    # associate left-to-right, so without the parentheses the
                    # formatted message would be repeated `epochs` times.
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break

            if do_validation:
                val_outs = test_loop(
                    model,
                    val_inputs,
                    val_targets,
                    sample_weights=val_sample_weights,
                    batch_size=batch_size,
                    steps=validation_steps,
                    verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o
        else:
            if shuffle == 'batch':
                index_array = training_utils.batch_shuffle(
                    index_array, batch_size)
            elif shuffle:
                np.random.shuffle(index_array)

            batches = make_batches(num_train_samples, batch_size)

            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                try:
                    if isinstance(ins[-1], int):
                        # Do not slice the training phase flag.
                        ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
                    else:
                        ins_batch = slice_arrays(ins, batch_ids)
                except TypeError:
                    raise TypeError('TypeError while preparing batch. '
                                    'If using HDF5 input data, '
                                    'pass shuffle="batch".')
                batch_logs = {}
                batch_logs['batch'] = batch_index
                batch_logs['size'] = len(batch_ids)
                callbacks.on_batch_begin(batch_index, batch_logs)
                for i in indices_for_conversion_to_dense:
                    ins_batch[i] = ins_batch[i].toarray()

                outs = f(ins_batch)
                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)
                if callback_model.stop_training:
                    break

                if batch_index == len(batches) - 1:  # Last batch.
                    if do_validation:
                        val_outs = test_loop(
                            model,
                            val_inputs,
                            val_targets,
                            sample_weights=val_sample_weights,
                            batch_size=batch_size,
                            verbose=0)
                        if not isinstance(val_outs, list):
                            val_outs = [val_outs]
                        # Same labels assumed.
                        for l, o in zip(out_labels, val_outs):
                            epoch_logs['val_' + l] = o
        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()
    return model.history
def train(self,
          train_data,
          val_data=None,
          epochs=200,
          early_stopping=None,
          verbose=1,
          save_best=True,
          ckpt_path=None,
          as_model=False,
          monitor='val_accuracy',
          early_stop_metric='val_loss',
          callbacks=None,
          **kwargs):
    """Train the model for the input `train_data` of nodes or `sequence`.

    Note:
    ----------
    You must compile your model before training/testing/predicting.
    Use `model.build()`.

    Parameters:
    ----------
    train_data: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
        The index of objects (or sequence) that will be used during training.
    val_data: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
        The index of objects (or sequence) that will be used for validation.
        (default :obj: `None`, i.e., do not use validation during training)
    epochs: Positive integer
        The number of epochs of training. (default :obj: `200`)
    early_stopping: Positive integer or None
        The number of early stopping patience during training.
        (default :obj: `None`, i.e., do not use early stopping during training)
    verbose: int in {0, 1, 2, 3, 4}
        'verbose=0': not verbose;
        'verbose=1': Progbar (one line, detailed);
        'verbose=2': Progbar (one line, omitted);
        'verbose=3': Progbar (multi line, detailed);
        'verbose=4': Progbar (multi line, omitted);
        (default :obj: 1)
    save_best: bool
        Whether to checkpoint the best weights (or whole model, see
        `as_model`) as judged by `monitor`, and reload them once training
        finishes. (default :bool: `True`)
    ckpt_path: String or None
        The path of saved weights/model. Only used when `save_best=True`.
        (default to `self.ckpt_path`.)
    as_model: bool
        Whether to save the whole model or weights only. If `True`, the
        `self.custom_objects` must be specified if you are using custom
        `layer` or `loss` and so on.
    monitor: String
        One of evaluation metrics, e.g., val_loss, val_accuracy, loss,
        accuracy. It determines which metric will be used for `save_best`.
        If it is not among the model's metric names, the last available
        metric name is used instead and a `UserWarning` is emitted.
        (default :obj: `val_accuracy`)
    early_stop_metric: String
        One of evaluation metrics, e.g., val_loss, val_accuracy, loss,
        accuracy. It determines which metric will be used for early stopping.
        (default :obj: `val_loss`)
    callbacks: tensorflow.keras.callbacks.
        (default :obj: `None`)
    kwargs: other keyword Parameters.
        `es_verbose` (default 1) sets the verbosity of the early-stopping
        callback; every other keyword is handed to `raise_if_kwargs`.

    Return:
    ----------
    A `tf.keras.callbacks.History` object. Its `History.history` attribute is
    a record of training loss values and metrics values at successive epochs,
    as well as validation loss values and validation metrics values
    (if applicable).

    Raises:
    ----------
    ValueError: if `verbose` is not an int in [0, 4].
    RuntimeError: if the model has not been built, or exposes no
        `metrics_names`.
    """
    # Consume the one supported extra keyword before `raise_if_kwargs`
    # inspects the rest, otherwise `es_verbose` could never reach the
    # EarlyStopping callback below.
    es_verbose = kwargs.pop('es_verbose', 1)
    raise_if_kwargs(kwargs)
    if not (isinstance(verbose, int) and 0 <= verbose <= 4):
        raise ValueError("'verbose=0': not verbose, "
                         "'verbose=1': Progbar(one line, detailed), "
                         "'verbose=2': Progbar(one line, omitted), "
                         "'verbose=3': Progbar(multi line, detailed), "
                         "'verbose=4': Progbar(multi line, omitted), "
                         f"but got {verbose}")

    model = self.model
    # Check if model has been built
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    metrics_names = getattr(model, "metrics_names", None)
    # FIXME: This would return '[]' for tensorflow>=2.2.0
    # See <https://github.com/tensorflow/tensorflow/issues/37990>
    # metrics_names = ['loss', 'accuracy']
    if not metrics_names:
        raise RuntimeError("Please specify the attribute 'metrics_names' for the model.")

    if not isinstance(train_data, Sequence):
        train_data = self.train_sequence(train_data)
    self.train_data = train_data

    validation = val_data is not None
    if validation:
        if not isinstance(val_data, Sequence):
            val_data = self.test_sequence(val_data)
        self.val_data = val_data
        # Validation metrics are reported under a "val_" prefix.
        metrics_names = metrics_names + ["val_" + metric for metric in metrics_names]

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks)

    history = History()
    callbacks.append(history)

    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=es_verbose)
        callbacks.append(es_callback)

    if save_best:
        if not ckpt_path:
            ckpt_path = self.ckpt_path
        else:
            self.ckpt_path = ckpt_path

        makedirs_from_filepath(ckpt_path)

        if not ckpt_path.endswith(gg.file_ext()):
            ckpt_path = ckpt_path + gg.file_ext()

        if monitor not in metrics_names:
            # Report the user's (invalid) metric name, not the fallback.
            fallback = metrics_names[-1]
            warnings.warn(f"'{monitor}' are not included in the metrics names. default to '{fallback}'.",
                          UserWarning)
            monitor = fallback

        mc_callback = ModelCheckpoint(ckpt_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)

    callbacks.set_model(model)
    model.stop_training = False

    if verbose:
        # verbose 1/2: one progress bar over epochs;
        # verbose 3/4: a fresh per-epoch bar is created inside the loop.
        if verbose <= 2:
            progbar = Progbar(target=epochs, width=20, verbose=verbose)
        print("Training...")

    logs = BunchDict()
    callbacks.on_train_begin()
    try:
        for epoch in range(epochs):
            if verbose > 2:
                progbar = Progbar(target=len(train_data), width=20, verbose=verbose - 2)

            callbacks.on_epoch_begin(epoch)
            # A whole epoch is run by one `train_step` call, so it is
            # reported to callbacks as a single batch.
            callbacks.on_train_batch_begin(0)
            train_logs = self.train_step(train_data)
            train_data.on_epoch_end()
            logs.update(train_logs)

            if validation:
                valid_logs = self.test_step(val_data)
                logs.update({("val_" + k): v for k, v in valid_logs.items()})
                val_data.on_epoch_end()

            callbacks.on_train_batch_end(len(train_data), logs)
            callbacks.on_epoch_end(epoch, logs)

            if verbose > 2:
                print(f"Epoch {epoch+1}/{epochs}")
                progbar.update(len(train_data), logs.items())
            elif verbose:
                progbar.update(epoch + 1, logs.items())

            if model.stop_training:
                print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                break

        callbacks.on_train_end()
        # A best checkpoint only exists when ModelCheckpoint was registered;
        # without it `ckpt_path` may still be None.
        if save_best:
            self.load(ckpt_path, as_model=as_model)
    finally:
        # Always clean up, even if training raised; `remove_weights`
        # presumably deletes the temporary checkpoint file — confirm.
        self.remove_weights()

    return history
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        auto_switch=True,
        retry_fit=True,
        absorb=True,
        train_after_switch=True,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        revert_after_fit=False):
    """
    Custom fit function for the context model.

    Follows the structure of the Keras `Model.fit` training loop. Within each
    epoch, however, the training pass over that epoch's window may be repeated.
    After every pass, `update_and_switch` decides whether a context switch
    happened. If it did (and `retry_fit` allows), the weights are restored to
    their start-of-epoch values and the pass is re-run.

    Context-specific arguments (the rest match `tf.keras.Model.fit`):
        auto_switch: Enable/disable autonomous context switching (forwarded
            to `update_and_switch`).
        retry_fit: Locate the next fitting context by re-performing fit.
            When False, at most one training pass is run per epoch.
        absorb: Reset the switch sequence counter upon successful training.
            This is mainly used to maintain switch sequencing for
            temporally-extended tasks (forwarded to `update_and_switch`).
        train_after_switch: When False, the start-of-epoch weights are
            restored whenever any switch occurred during the epoch, so the
            pass that followed the switch does not keep its weight updates.
        revert_after_fit: Debug parameter that restores the start-of-epoch
            weights after every pass. This is used to calculate the context
            deltas without incorrectly learning while auto switching is
            disabled.

    `verbose` is treated as a bit mask: a progress bar is added only when the
    `Verbosity.Progress` bit is set.

    Returns:
        `self.history`, the History callback populated during training.
    """
    training._keras_api_gauge.get_cell('fit').set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph('Model', 'fit')
    self._assert_compile_was_called()
    self._check_call_args('fit')

    if validation_split:
        # Create the validation data using the training data. Only supported for
        # `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight),
                validation_split=validation_split,
                shuffle=False))

    with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        # Unlike Keras' DataHandler, each epoch yields a window iterator
        # (zipped into a dataset below) rather than a plain iterator.
        data_handler = WindowedDataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=bool(verbose & Verbosity.Progress),
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps)

        self.stop_training = False
        train_function = self.make_train_function()
        callbacks.on_train_begin()
        self.initialize_fit()
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        for epoch, window_iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            # The same dataset is re-iterated by every retry within this epoch.
            dataset = tf.data.Dataset.zip(next(window_iterator))
            # Indicate if the model has attempted at least one switch during this epoch
            switched_during_epoch = False
            # Indicate if the model switched on the most recent fit iteration
            # (True here only to force the first pass).
            switched = True
            # Snapshot of the start-of-epoch weights, used to undo a pass.
            weights = backend.batch_get_value(self.trainable_variables)
            # Perform a 'fit call'. Assuming retry_fit, this call is re-attempted after each switch until a context fits
            while switched and (retry_fit or not switched_during_epoch):
                self.initialize_epoch(epoch)
                iterator = iter(dataset)

                # Perform a fit call
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with traceme.TraceMe('TraceContext',
                                             graph_type='train',
                                             epoch_num=epoch,
                                             step_num=step,
                                             batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            tmp_logs = train_function(iterator)
                            # Catch OutOfRangeError for Datasets of unknown size.
                            # This blocks until the batch has finished executing.
                            # TODO(b/150292341): Allow multiple async steps here.
                            if not data_handler.inferred_steps:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            callbacks.on_train_batch_end(step, logs)

                # NOTE(review): `update_and_switch` appears to return a truthy
                # value when the current context fit (no switch) — hence the
                # `not`. Confirm against its definition.
                switched = not self.update_and_switch(
                    epoch, auto_switch, absorb, retry_fit, verbose)
                switched_during_epoch |= switched

                # If a switch occurred, we need to restore the weights
                if switched or (switched_during_epoch and not train_after_switch
                                ) or revert_after_fit:
                    backend.batch_set_value(
                        zip(self.trainable_variables, weights))
                    self.reset_metrics()

            # NOTE(review): `logs` is unbound here if no training step ran in
            # any epoch so far (e.g. an empty dataset) — NameError.
            epoch_logs = copy.copy(logs)

            # Run validation.
            if validation_data and self._should_eval(
                    epoch, validation_freq):
                val_x, val_y, val_sample_weight = (
                    data_adapter.unpack_x_y_sample_weight(validation_data))
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True)
                val_logs = {
                    'val_' + name: val
                    for name, val in val_logs.items()
                }
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.stop_training:
                break

        callbacks.on_train_end()
        return self.history
def _interleave_images(samples, xs, limit=64):
    """Alternate ``samples[j]`` and ``xs[j]`` along axis 0 (first ``limit`` of each).

    The result has the shape/dtype of
    ``np.concatenate([samples[:limit], xs[:limit]])``. If the two slices
    differ in length, the unpaired tail rows are left as zeros.
    """
    left, right = samples[:limit], xs[:limit]
    joined = np.zeros_like(np.concatenate([left, right], axis=0))
    for j, (first, second) in enumerate(zip(left, right)):
        joined[2 * j] = first
        joined[2 * j + 1] = second
    return joined


def train_model(name, g_train, d_train, sampler, generator, samples_per_epoch,
                epochs, z_dim=100, verbose=1, callbacks=None, saver=None):
    """Main GAN training loop. Modified version of Keras' fit_generator.

    Each step draws two ``(z, x)`` batches from `generator`: one for the
    discriminator update and a fresh one for the generator update. Every
    100 batches, a train and a test (via `sampler`) sample grid are saved to
    ``./outputs/samples_<name>/{train,test}_<epoch>_<batch>.png``.

    Args:
        name: Tag used in the sample-image output directory name.
        g_train: Called as ``g_train(x, z, counter)``; must return
            ``(g_loss, samples, xs)``.
        d_train: Called as ``d_train(x, z, counter)``; must return a tuple.
            It is presumably ``(d_loss, d_loss_fake, d_loss_legit)``, to line
            up with the logged labels — confirm.
        sampler: Called as ``sampler(z, x)``; must return ``(samples, xs)``.
        generator: Iterator yielding ``(z, x)`` pairs; ``x`` may be an
            array, a list of arrays, or a dict of arrays.
        samples_per_epoch: Number of samples that make up one epoch.
        epochs: Number of epochs to run.
        z_dim: Unused by this loop; kept for interface compatibility.
        verbose: If truthy, a `ProgbarLogger` callback is added.
        callbacks: Optional list of extra Keras callbacks.
        saver: Optional callable, invoked as ``saver(epoch)`` after each epoch.

    Returns:
        The `History` callback that recorded this run.
    """
    epoch = 0
    counter = 0  # global step counter passed to d_train / g_train
    out_labels = ['g_loss', 'd_loss', 'd_loss_fake', 'd_loss_legit', 'time']  # self.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    history = cbks.History()
    callbacks = [cbks.BaseLogger()] + list(callbacks or []) + [history]
    if verbose:
        callbacks += [cbks.ProgbarLogger()]
    callbacks = cbks.CallbackList(callbacks)

    callback_params = {
        'epochs': epochs,
        'samples': samples_per_epoch,
        'verbose': verbose,
        'metrics': callback_metrics,
    }
    callbacks.set_params(callback_params)
    callbacks.on_train_begin()

    while epoch < epochs:
        callbacks.on_epoch_begin(epoch)
        samples_seen = 0
        batch_index = 0
        # Bound before the batch loop so it exists even when no batch runs
        # (samples_per_epoch <= 0).
        epoch_logs = {}
        while samples_seen < samples_per_epoch:
            z, x = next(generator)

            # build batch logs
            if isinstance(x, list):
                batch_size = len(x[0])
            elif isinstance(x, dict):
                batch_size = len(list(x.values())[0])
            else:
                batch_size = len(x)
            batch_logs = {'batch': batch_index, 'size': batch_size}
            callbacks.on_batch_begin(batch_index, batch_logs)

            t1 = time.time()
            d_losses = d_train(x, z, counter)
            # The generator step uses a fresh batch.
            z, x = next(generator)
            g_loss, samples, xs = g_train(x, z, counter)
            outs = (g_loss, ) + d_losses + (time.time() - t1, )
            counter += 1

            # save samples
            if batch_index % 100 == 0:
                save_images(
                    _interleave_images(samples, xs), [8 * 2, 8],
                    './outputs/samples_%s/train_%s_%s.png' %
                    (name, epoch, batch_index))
                samples, xs = sampler(z, x)
                save_images(
                    _interleave_images(samples, xs), [8 * 2, 8],
                    './outputs/samples_%s/test_%s_%s.png' %
                    (name, epoch, batch_index))

            for label, value in zip(out_labels, outs):
                batch_logs[label] = value
            callbacks.on_batch_end(batch_index, batch_logs)

            batch_index += 1
            samples_seen += batch_size

        if saver is not None:
            saver(epoch)

        callbacks.on_epoch_end(epoch, epoch_logs)
        epoch += 1

    callbacks.on_train_end()
    return history