Example #1
# Assumed imports for this snippet (module paths may vary across TF versions):
import tensorflow.keras as K
from tensorflow.python.keras import callbacks as callbacks_module


def get(base_dir, cfg, model, train_steps, **params):
    callbacks = [
        K.callbacks.TerminateOnNaN(),
    ]
    if cfg.ckpt:
        callbacks.append(
            K.callbacks.ModelCheckpoint(base_dir +
                                        f"/{cfg.tag}/ckpt/{cfg.ckpt}",
                                        save_weights_only=True,
                                        verbose=1))
    if cfg.best_ckpt:
        callbacks.append(
            K.callbacks.ModelCheckpoint(base_dir +
                                        f"/{cfg.tag}/ckpt/{cfg.best_ckpt}",
                                        save_best_only=True,
                                        save_weights_only=True,
                                        verbose=1))
    if cfg.tensorboard:
        callbacks.append(
            K.callbacks.TensorBoard(base_dir + f"/{cfg.tag}/{cfg.tensorboard}",
                                    write_graph=False))
    if cfg.lrp:
        from . import optimizer
        callbacks.append(optimizer.lr_callback(cfg))

    # Training-loop parameters forwarded to the CallbackList; caller-supplied
    # **params may override them.
    final_params = {
        "verbose": cfg.verbose,
        "epochs": cfg.total_epochs,
        "steps": train_steps,
        **params,
    }
    return callbacks_module.CallbackList(callbacks,
                                         add_history=True,
                                         add_progbar=cfg.verbose != 0,
                                         model=model,
                                         **final_params)
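
The CallbackList returned by `get` is meant to be driven by a custom training
loop. Below is a minimal sketch of such a driver, assuming a compiled Keras
`model`, a `tf.data.Dataset` of `(x, y)` batches, and a `cfg` carrying the
fields used above; the loop itself is illustrative and not part of the
original snippet.

import tensorflow as tf

def run_training(model, dataset, cfg, base_dir, train_steps):
    callbacks = get(base_dir, cfg, model, train_steps)
    callbacks.on_train_begin()
    logs = {}
    for epoch in range(cfg.total_epochs):
        callbacks.on_epoch_begin(epoch)
        for step, (x, y) in enumerate(dataset.take(train_steps)):
            callbacks.on_train_batch_begin(step)
            logs = model.train_on_batch(x, y, return_dict=True)
            callbacks.on_train_batch_end(step, logs)
        callbacks.on_epoch_end(epoch, logs)
        if model.stop_training:  # set, e.g., by TerminateOnNaN
            break
    callbacks.on_train_end()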
Example #2
    def fit(
        self,
        x=None,
        y=None,
        validation_data=None,
        epochs=1,
        verbose=0,
        callbacks=None,
        **kwargs,
    ):
        """ If the optimizer is genetic, the fitting procedure consists on executing `run_step` for
        the given number of epochs.
        """
        if self.is_genetic:
            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=False,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                )
            callbacks.on_train_begin()

            result = self.perform_genetic_fit(
                x=x,
                y=y,
                epochs=epochs,
                verbose=verbose,
                validation_data=validation_data,
                callbacks=callbacks,
            )

        else:
            result = super().fit(
                x=x,
                y=y,
                validation_data=validation_data,
                epochs=epochs,
                verbose=verbose,
                callbacks=callbacks,
                **kwargs,
            )
        return result
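
For reference, the wrapping performed above can be exercised standalone. A
minimal sketch of what `CallbackList(..., add_history=True)` provides,
assuming `callbacks_module` is `tensorflow.python.keras.callbacks` (the toy
model and logs are made up for illustration):

import tensorflow as tf
from tensorflow.python.keras import callbacks as callbacks_module

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

cb_list = callbacks_module.CallbackList(
    [tf.keras.callbacks.TerminateOnNaN()],
    add_history=True,   # attaches a History callback, exposed as model.history
    add_progbar=False,
    model=model,
    verbose=0,
    epochs=2,
)
cb_list.on_train_begin()
for epoch in range(2):
    cb_list.on_epoch_begin(epoch)
    cb_list.on_epoch_end(epoch, {"loss": 1.0 / (epoch + 1)})
cb_list.on_train_end()
print(model.history.history)   # {'loss': [1.0, 0.5]}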
Example #3
def accuracy_aware_fit(cls_instance,
                       train_dataset,
                       compression_ctrl,
                       nncf_config,
                       callbacks,
                       initial_epoch,
                       uncompressed_model_accuracy,
                       steps_per_epoch=None,
                       batch_size=None,
                       tensorboard_writer=None,
                       log_dir=None,
                       validation_data=None,
                       validation_steps=None,
                       result_dict_to_val_metric_fn=None,
                       **kwargs):
    if result_dict_to_val_metric_fn is None:
        result_dict_to_val_metric_fn = lambda metric: metric

    with cls_instance.distribute_strategy.scope(), \
        training_utils.RespectCompiledTrainableState(cls_instance):
        # pylint: disable=protected-access
        data_handler = data_adapter.DataHandler(
            x=train_dataset,
            y=None,
            sample_weight=None,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=1,
            shuffle=True,
            class_weight=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            model=cls_instance,
            steps_per_execution=cls_instance._steps_per_execution)

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                model=cls_instance,
                epochs=1,
                verbose=1,
                add_progbar=True,
                steps=data_handler.inferred_steps)

    def train_epoch_fn(compression_ctrl, model, epoch):
        model.reset_metrics()

        if model.train_function is None:
            model.train_function = model.make_train_function()
        _, iterator = next(data_handler.enumerate_epochs())
        logs = None  # remains None if the dataset yields no batches

        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
            for step in data_handler.steps():
                with trace.Trace('train',
                                 epoch_num=epoch,
                                 step_num=step,
                                 batch_size=None,
                                 _r=1):
                    callbacks.on_train_batch_begin(step)
                    tmp_logs = model.train_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
                    logs = tmp_logs
                    end_step = step + data_handler.step_increment
                    callbacks.on_train_batch_end(end_step, logs)
                    if model.stop_training:
                        break

        if logs is None:
            raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)
        callbacks.on_epoch_end(epoch, epoch_logs)

    if validation_data is None:
        validation_data = train_dataset

    def validate_fn(model, epoch=None):
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
        val_logs = model.evaluate(x=val_x,
                                  y=val_y,
                                  sample_weight=val_sample_weight,
                                  batch_size=None,
                                  steps=validation_steps,
                                  callbacks=callbacks,
                                  return_dict=True)
        return result_dict_to_val_metric_fn(val_logs)

    callbacks.on_train_begin()
    cls_instance.original_model_accuracy = uncompressed_model_accuracy
    acc_aware_training_loop = create_accuracy_aware_training_loop(
        nncf_config, compression_ctrl)
    cls_instance = acc_aware_training_loop.run(
        cls_instance,
        train_epoch_fn=train_epoch_fn,
        validate_fn=validate_fn,
        tensorboard_writer=tensorboard_writer,
        log_dir=log_dir)
    callbacks.on_train_end()
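
`validate_fn` above unpacks `validation_data` with
`data_adapter.unpack_x_y_sample_weight`, which accepts a bare `x`, an
`(x, y)` pair, or an `(x, y, sample_weight)` triple. A quick illustration
(note this is a private TF API whose module path can change between
versions):

import numpy as np
from tensorflow.python.keras.engine import data_adapter

features, labels = np.ones((4, 2)), np.zeros(4)

x, y, sw = data_adapter.unpack_x_y_sample_weight((features, labels))
print(y is labels, sw)   # True None

x, y, sw = data_adapter.unpack_x_y_sample_weight(features)
print(y, sw)             # None None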
Example #4
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """ From tf.keras.Model. """
        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')
        training._disallow_inside_tf_function('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight), validation_split=validation_split))

        if validation_data:
            val_x, val_y, val_sample_weight = (
                data_adapter.unpack_x_y_sample_weight(validation_data))

        with self.distribute_strategy.scope(), \
             training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = data_adapter.DataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            self._train_counter.assign(0)
            callbacks.on_train_begin()
            training_logs = None
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (  # pylint: disable=protected-access
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                with data_handler.catch_stop_iteration():
                    if self._update_cycle > 1:
                        self._grad_accumulator.reset()
                    for step in data_handler.steps():
                        with trace.Trace(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            if self._update_cycle > 1:
                                for _ in range(self._update_cycle - 1):
                                    self.accumulate_function(iterator)
                            tmp_logs = train_function(iterator)
                            if data_handler.should_sync:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            end_step = step + data_handler.step_increment
                            callbacks.on_train_batch_end(end_step, logs)
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    # Create data_handler for evaluation and cache it.
                    if getattr(self, '_eval_data_handler', None) is None:
                        self._eval_data_handler = data_adapter.DataHandler(
                            x=val_x,
                            y=val_y,
                            sample_weight=val_sample_weight,
                            batch_size=validation_batch_size or batch_size,
                            steps_per_epoch=validation_steps,
                            initial_epoch=0,
                            epochs=1,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            model=self,
                            steps_per_execution=self._steps_per_execution)
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                training_logs = epoch_logs
                if self.stop_training:
                    break

            # If an eval data_handler exists, delete it after all epochs are done.
            if getattr(self, '_eval_data_handler', None) is not None:
                del self._eval_data_handler
            callbacks.on_train_end(logs=training_logs)
            return self.history
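
The `_update_cycle` branch above implements gradient accumulation:
`_update_cycle - 1` calls to `accumulate_function` only gather gradients, and
the final `train_function` applies the combined update. A minimal standalone
sketch of the same idea (the helper below is illustrative, not the snippet's
actual implementation):

import tensorflow as tf

def accumulated_train_step(model, optimizer, loss_fn, batches, cycle=4):
    """Accumulate gradients over `cycle` batches, then apply their mean once."""
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for x, y in batches[:cycle]:
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    optimizer.apply_gradients(
        [(a / cycle, v) for a, v in zip(accum, model.trainable_variables)])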
Example #5
    def build_callbacks(self, conf, callbacks_list):
        '''
        Set up logging and history callbacks. Based on Keras Callbacks:
        https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

        Currently used callbacks: BaseLogger, CSVLogger, EarlyStopping.
        Other callbacks that could be added in the future: RemoteMonitor,
        LearningRateScheduler.

        Arguments:
        - conf: configuration dict; uses the "callbacks" section of the
          conf.yaml file. Relevant parameters are:

          - list: additional callbacks to enable, read in the driver script
            and passed via `callbacks_list` (see next argument)

          - metrics: list of quantities monitored during training and
            validation

          - mode: one of {auto, min, max}. The decision to overwrite the
            current save file is made based on either the maximization or the
            minimization of the monitored quantity. For val_acc this should be
            max, for val_loss min, etc. In auto mode, the direction is
            automatically inferred from the name of the monitored quantity.

          - monitor: quantity used for early stopping; must be one of the
            metrics

          - patience: number of epochs to wait before deciding whether to
            apply early stopping or continue training

        - callbacks_list: the callbacks.list configuration parameter,
          specifying additional callbacks to enable

        Returns: a CallbackList with the configured callbacks.
        '''

        mode = conf['callbacks']['mode']
        monitor = conf['callbacks']['monitor']
        patience = conf['callbacks']['patience']
        csvlog_save_path = conf['paths']['csvlog_save_path']
        # CSV callback is on by default
        if not os.path.exists(csvlog_save_path):
            os.makedirs(csvlog_save_path)

        callbacks_list = conf['callbacks']['list']
        callbacks = [cbks.BaseLogger()]
        callbacks += [self.history]
        callbacks += [
            cbks.CSVLogger("{}callbacks-{}.log".format(
                csvlog_save_path,
                datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
        ]

        if "earlystop" in callbacks_list:
            callbacks += [
                cbks.EarlyStopping(patience=patience,
                                   monitor=monitor,
                                   mode=mode)
            ]
        if "lr_scheduler" in callbacks_list:
            pass

        return cbks.CallbackList(callbacks)
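
A hedged example of the `conf` mapping `build_callbacks` expects, with keys
inferred from the accesses above (the values are illustrative):

conf = {
    'callbacks': {
        'mode': 'min',           # direction for the monitored quantity
        'monitor': 'val_loss',   # quantity watched by EarlyStopping
        'patience': 5,           # epochs to wait before stopping
        'list': ['earlystop'],   # additional callbacks to enable
    },
    'paths': {
        'csvlog_save_path': 'logs/csv/',
    },
}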
Example #6
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            dynamic_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(), \
             training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                weights = backend.batch_get_value(self.trainable_variables)
                while switched:
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, dynamic_switch, verbose)
                        # If a switch occurred, we need to restore the weights
                        if switched:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                if self.accumulate_gradients:
                    self.optimizer.apply_gradients(
                        zip(self.accumulated_gradients,
                            self.trainable_variables))

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
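
The `batch_get_value`/`batch_set_value` pair above snapshots the trainable
weights before the inner loop and rolls them back whenever
`update_and_switch` reports a switch. The same save-and-restore pattern in
isolation (a toy model stands in for `self`):

import tensorflow as tf
from tensorflow.keras import backend

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.build((None, 3))

# Snapshot current values, perturb them, then roll back.
snapshot = backend.batch_get_value(model.trainable_variables)
for v in model.trainable_variables:
    v.assign_add(tf.ones_like(v))
backend.batch_set_value(list(zip(model.trainable_variables, snapshot)))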
Example #7
    def train(self, train_data, val_data=None, **kwargs):
        cache = self.cache
        cfg = self.cfg.train
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        cache.train_data = train_data

        validation = val_data is not None

        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            cache.val_data = val_data
        elif ckpt_cfg.enabled and ckpt_cfg.monitor.startswith("val_"):
            ckpt_cfg.monitor = ckpt_cfg.monitor[4:]
            warnings.warn(
                f"The metric 'val_{ckpt_cfg.monitor}' is invalid without validation "
                f"and has been automatically replaced with '{ckpt_cfg.monitor}'.",
                UserWarning)

        callbacks = callbacks_module.CallbackList()

        history = History()
        callbacks.append(history)

        if es_cfg.enabled:
            assert es_cfg.monitor.startswith("val")
            es_callback = EarlyStopping(monitor=es_cfg.monitor,
                                        patience=es_cfg.patience,
                                        mode=es_cfg.mode,
                                        verbose=es_cfg.verbose)
            callbacks.append(es_callback)

        if ckpt_cfg.enabled:
            if not ckpt_cfg.path.endswith(gg.file_ext()):
                ckpt_cfg.path += gg.file_ext()
            makedirs_from_filepath(ckpt_cfg.path)

            mc_callback = ModelCheckpoint(
                ckpt_cfg.path,
                monitor=ckpt_cfg.monitor,
                save_best_only=ckpt_cfg.save_best_only,
                save_weights_only=ckpt_cfg.save_weights_only,
                verbose=ckpt_cfg.verbose)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        if verbose:
            if verbose <= 2:
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v
                                 for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)

        finally:
            # to avoid unexpected termination of the model
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
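
The `cfg.train` block consumed above bundles one sub-config per callback. A
hypothetical shape, with field names taken from the attribute accesses in the
code and illustrative values:

cfg_train = {
    "verbose": 1,
    "epochs": 200,
    "ModelCheckpoint": {
        "enabled": True,
        "path": "checkpoints/model",   # gg.file_ext() is appended if missing
        "monitor": "val_acc",
        "save_best_only": True,
        "save_weights_only": True,
        "verbose": 0,
        "remove_weights": True,
    },
    "EarlyStopping": {
        "enabled": True,
        "monitor": "val_loss",
        "patience": 10,
        "mode": "auto",
        "verbose": 0,
    },
    "Progbar": {"width": 30},
}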
Example #8
    def train(self,
              idx_train,
              idx_val=None,
              epochs=200,
              early_stopping=None,
              verbose=0,
              save_best=True,
              weight_path=None,
              as_model=False,
              monitor='val_acc',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training.(default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2}
            'verbose=0': not verbose;
            'verbose=1': tqdm verbose;
            'verbose=2': tensorflow progbar verbose;
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on
            `monitor`) of training or validation (depending on whether
            validation is used). (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`,
            `self.custom_objects` must be specified if you are using a
            customized `layer`, `loss`, and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        """
        if verbose not in {0, 1, 2}:
            raise ValueError(
                "'verbose=0': not verbose; 'verbose=1': tqdm verbose; "
                "'verbose=2': tensorflow progbar verbose; "
                f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        # TODO: add metric names in `model`
        metric_names = ['loss', 'acc']
        callback_metrics = metric_names
        model.stop_training = False

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
            callback_metrics = copy.copy(metric_names)
            callback_metrics += ['val_' + n for n in metric_names]
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = tf_History()
        callbacks.append(history)

        if verbose == 2:
            callbacks.append(ProgbarLogger(stateful_metrics=metric_names[1:]))

        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 1))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path

            makedirs_from_path(weight_path)

            if not weight_path.endswith('.h5'):
                weight_path = weight_path + '.h5'

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        # TODO: to be improved
        callback_params = {
            'batch_size': None,
            'epochs': epochs,
            'steps': 1,
            'samples': 1,
            'verbose': verbose == 2,
            'do_validation': validation,
            'metrics': callback_metrics,
        }
        callbacks.set_params(callback_params)
        raise_if_kwargs(kwargs)

        callbacks.on_train_begin()

        if verbose == 1:
            pbar = tqdm(range(1, epochs + 1))
        else:
            pbar = range(epochs)

        for epoch in pbar:
            callbacks.on_epoch_begin(epoch)

            callbacks.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_data)

            training_logs = {'loss': loss, 'acc': accuracy}

            if validation:
                val_loss, val_accuracy = self.test_step(val_data)
                training_logs.update({
                    'val_loss': val_loss,
                    'val_acc': val_accuracy
                })
                val_data.on_epoch_end()
            callbacks.on_train_batch_end(0, training_logs)
            callbacks.on_epoch_end(epoch, training_logs)

            if verbose == 1:
                msg = "<"
                for key, val in training_logs.items():
                    msg += f"{key.title()} = {val:.4f} "
                msg += ">"
                pbar.set_description(msg)
            train_data.on_epoch_end()

            if verbose == 2:
                print()

            if model.stop_training:
                break

        callbacks.on_train_end()

        if save_best:
            self.load(weight_path, as_model=as_model)
            remove_tf_weights(weight_path)

        return history
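
A hedged usage sketch for `train` above (graphgallery-style trainer; the
`trainer` instance, index values, and metric access are illustrative, not
from the snippet):

# Hypothetical usage; assumes `trainer` is an instance of the class above,
# already compiled with `model.build()`.
import numpy as np

idx_train = np.arange(140)       # training node indices
idx_val = np.arange(140, 640)    # validation node indices

history = trainer.train(idx_train,
                        idx_val=idx_val,
                        epochs=200,
                        early_stopping=20,   # patience, in epochs
                        verbose=1,           # tqdm progress bar
                        monitor='val_acc',   # metric used for save_best
                        early_stop_metric='val_loss')
print(history.history['val_acc'][-1])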
Example #9
def fit_loop(model,
             inputs,
             targets,
             sample_weights=None,
             class_weight=None,
             val_inputs=None,
             val_targets=None,
             val_sample_weights=None,
             batch_size=None,
             epochs=1,
             verbose=1,
             callbacks=None,
             shuffle=True,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit function for eager execution.

  Arguments:
      model: Instance of the model that is being executed in Eager mode.
      inputs: List of input arrays.
      targets: List of target arrays.
      sample_weights: Optional list of sample weight arrays.
      class_weight: Optional class-weight array to weight the importance of
          samples in `inputs` based on the class they belong to, as conveyed by
          `targets`.
      val_inputs: Input data for validation.
      val_targets: Target data for validation.
      val_sample_weights: Sample weight data for validation.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      shuffle: Whether to shuffle the data at the beginning of each epoch
      callback_metrics: List of strings, the display names of the metrics
          passed to the callbacks. They should be the concatenation of the
          display names of the outputs of `f` and the display names of the
          outputs of `f_val`.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for (only if doing
        validation from data tensors). Ignored with default value of `None`.

  Returns:
      `History` object.

  Raises:
    ValueError: In case of invalid argument values.
  """
    # Convert training inputs to an EagerIterator
    inputs, steps_per_epoch = training_utils.convert_to_iterator(
        x=inputs,
        y=targets,
        sample_weights=sample_weights,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        shuffle=shuffle)
    # Required for eager execution
    with backend.learning_phase_scope(1):
        do_validation = False
        if val_inputs:
            do_validation = True

        num_train_samples = None
        out_labels = None
        if model._is_compiled:
            out_labels = model.metrics_names
            if do_validation:
                callback_metrics = copy.copy(out_labels) + [
                    'val_' + n for n in out_labels
                ]
            else:
                callback_metrics = copy.copy(out_labels)

        model.history = cbks.History()
        callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
        if verbose:
            callbacks += [cbks.ProgbarLogger('steps')]
        callbacks = cbks.CallbackList(callbacks)

        # it's possible to callback a different model than self
        # (used by Sequential models)
        if hasattr(model, 'callback_model') and model.callback_model:
            callback_model = model.callback_model
        else:
            callback_model = model

        callbacks.set_model(callback_model)

        callback_params = {
            'batch_size': batch_size,
            'epochs': epochs,
            'steps': steps_per_epoch,
            'samples': num_train_samples,
            'verbose': verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics or [],
        }
        if validation_steps:
            callback_params.update({'validation_steps': validation_steps})
        callbacks.set_params(callback_params)

        for cbk in callbacks:
            if not val_inputs:
                cbk.validation_data = []
            elif isinstance(val_inputs, iterator_ops.EagerIterator):
                cbk.validation_data = val_inputs
            elif val_sample_weights:
                cbk.validation_data = val_inputs + val_targets + val_sample_weights
            else:
                cbk.validation_data = val_inputs + val_targets
        # validation_data must be set before on_train_begin() is called
        # so that TensorboardCallback can validate its input
        callbacks.on_train_begin()
        callback_model.stop_training = False

        for epoch in range(initial_epoch, epochs):
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            iterator_fit_loop(model,
                              inputs,
                              class_weight,
                              steps_per_epoch=steps_per_epoch,
                              callback_model=callback_model,
                              out_labels=out_labels,
                              epoch_logs=epoch_logs,
                              val_inputs=val_inputs,
                              val_targets=val_targets,
                              val_sample_weights=val_sample_weights,
                              epochs=epochs,
                              verbose=verbose,
                              callbacks=callbacks,
                              callback_metrics=callback_metrics,
                              validation_steps=validation_steps,
                              do_validation=do_validation,
                              batch_size=batch_size)
            callbacks.on_epoch_end(epoch, epoch_logs)
            if callback_model.stop_training:
                break
    callbacks.on_train_end()
    return model.history
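
Examples #9 and #11 use the older callback protocol, where the caller wires
the CallbackList manually with `set_model` and `set_params` and drives the
`on_batch_*` hooks itself. A minimal standalone sketch of that protocol
(`cbks` is assumed to be the Keras callbacks module; the logs are toy
values):

import tensorflow as tf
from tensorflow.python.keras import callbacks as cbks

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

history = cbks.History()
cb_list = cbks.CallbackList([cbks.BaseLogger(), history])
cb_list.set_model(model)
cb_list.set_params({'epochs': 2, 'steps': 1, 'verbose': 0,
                    'do_validation': False, 'metrics': ['loss']})
cb_list.on_train_begin()
for epoch in range(2):
    cb_list.on_epoch_begin(epoch)
    cb_list.on_batch_begin(0, {'batch': 0, 'size': 1})
    cb_list.on_batch_end(0, {'loss': 0.1})
    cb_list.on_epoch_end(epoch, {'loss': 0.1})
cb_list.on_train_end()
print(history.history)   # {'loss': [0.1, 0.1]}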
Example #10
    def train(self, train_data, val_data=None, **kwargs):
        cache = self.cache
        cfg = self.cfg.train
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        if cfg.cache_train_data:
            cache.train_data = train_data

        validation = val_data is not None
        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            if cfg.cache_val_data:
                cache.val_data = val_data

        # Setup callbacks
        callbacks = callbacks_module.CallbackList()
        history = History()
        callbacks.append(history)
        cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        if verbose:
            if verbose <= 2:
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)

        finally:
            # to avoid unexpected termination of the model
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
Example #11
def fit_loop(model,
             inputs,
             targets,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      callback_metrics: List of strings, the display names of the metrics
          passed to the callbacks. They should be the concatenation of the
          display names of the outputs of `f` and the display names of the
          outputs of `f_val`.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
    current_strategy = model._distribution_strategy

    def _per_device_train_function(model):
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_tower`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_train_function = K.Function(all_inputs,
                                            all_outputs,
                                            updates=all_updates,
                                            name='distributed_train_function',
                                            **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [
        None for _ in range(len(model.outputs) * current_strategy.num_towers)
    ]
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
        ins = dataset_inputs + dataset_targets

    do_validation = False
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')
    out_labels = model.metrics_names
    if do_validation:
        callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
        ]
    else:
        callback_metrics = copy.copy(out_labels)

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        # We assume that `steps_per_epoch` is always set since we have to use
        # Datasets.
        count_mode = 'steps'

        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # We set the callback model to an instance of the `DistributedModel` that we
    # create in the  `compile` call. The `DistributedModel` is initialized with
    # the first replicated model. We need to set the callback model to a
    # DistributedModel to allow us to override saving and loading weights when
    # we checkpoint the model during training.
    callback_model = model._replicated_model

    callbacks.set_model(callback_model)

    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': None,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False


    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}  # used by on_epoch_end below even when no steps run
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {}
                batch_logs['batch'] = step_index
                batch_logs['size'] = 1
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                # TODO(anjalisridhar): Temporary workaround for aggregating metrics
                # across towers. Replace with the new metrics module eventually.
                merged_output = []
                # The first output is the total loss.
                merged_output.append(outs[0])
                current_index = 1
                num_devices = len(current_strategy._devices)
                # Each label in `out_labels` corresponds to one set of metrics. The
                # number of metric values corresponds to the number of devices. We
                # currently take the mean of the values.
                for _ in out_labels[1:]:
                    m = np.mean(outs[current_index:current_index +
                                     num_devices])
                    merged_output.append(m)
                    current_index += num_devices

                for l, o in zip(out_labels, merged_output):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_inputs,
                                     val_targets,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
    return model.history
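
The metric-merging loop above averages each metric's per-tower values into a
single number per label. A small numpy illustration of the indexing (toy
values, two towers):

import numpy as np

# outs = [total_loss, m1_tower0, m1_tower1, m2_tower0, m2_tower1]
outs = [0.90, 0.80, 0.84, 0.70, 0.74]
out_labels = ['loss', 'metric_1', 'metric_2']
num_devices = 2

merged_output = [outs[0]]   # the total loss is already aggregated
current_index = 1
for _ in out_labels[1:]:
    merged_output.append(
        np.mean(outs[current_index:current_index + num_devices]))
    current_index += num_devices
print(merged_output)        # [0.9, 0.82, 0.72] (up to float rounding)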
Example #12
    def fit(self, train_data, val_data=None, **kwargs):

        cache = self.cache
        cfg = self.cfg.fit
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar
        log_cfg = cfg.Logger

        if log_cfg.enabled:
            log_cfg.name = log_cfg.name or self.name
            logger = gg.utils.setup_logger(output=log_cfg.filepath,
                                           name=log_cfg.name)

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, (Sequence, DataLoader, Dataset)):
            train_data = self.train_loader(train_data)

        if cfg.cache_train_data:
            cache.train_data = train_data

        validation = val_data is not None
        if validation:
            if not isinstance(val_data, (Sequence, DataLoader, Dataset)):
                val_data = self.test_loader(val_data)
            if cfg.cache_val_data:
                cache.val_data = val_data

        # Setup callbacks
        callbacks = callbacks_module.CallbackList()
        history = History()
        callbacks.append(history)
        cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        assert not (
            verbose and log_cfg.enabled
        ), "Progbar and Logger cannot be used together! You must set `verbose=0` when Logger is enabled."

        if verbose:
            if verbose <= 2:
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")
        elif log_cfg.enabled:
            logger.info("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                if hasattr(train_data, 'on_epoch_end'):
                    train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v
                                 for k, v in valid_logs.items()})
                    if hasattr(val_data, 'on_epoch_end'):
                        val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{cfg.epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())
                elif log_cfg.enabled:
                    logger.info(
                        f"Epoch {epoch+1}/{cfg.epochs}\n{gg.utils.create_table(logs)}"
                    )

                if model.stop_training:
                    if log_cfg.enabled:
                        logger.info(f"Early Stopping at Epoch {epoch}")
                    else:
                        print(f"Early Stopping at Epoch {epoch}",
                              file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)
        finally:
            # clean up checkpoint files even if training is interrupted
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
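
The loop above drives the Keras callback machinery by hand instead of going through `Model.fit`. A minimal, self-contained sketch of that hook sequence, assuming TF 2.x (the model, data, and loss values below are stand-ins, not part of the original code):

import tensorflow as tf
from tensorflow.python.keras import callbacks as callbacks_module

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
model.stop_training = False

history = tf.keras.callbacks.History()
callbacks = callbacks_module.CallbackList([history])
callbacks.set_model(model)

callbacks.on_train_begin()
for epoch in range(3):
    callbacks.on_epoch_begin(epoch)
    callbacks.on_train_batch_begin(0)
    logs = {"loss": 1.0 / (epoch + 1)}   # stand-in for a real `train_step`
    callbacks.on_train_batch_end(0, logs)
    callbacks.on_epoch_end(epoch, logs)  # History records `logs` here
    if model.stop_training:
        break
callbacks.on_train_end()
print(history.history)                   # {'loss': [1.0, 0.5, 0.333...]}
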
Exemple #13
0
def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
  """See docstring for `Model.fit_generator`."""
  wait_time = 0.01  # in seconds
  epoch = initial_epoch

  do_validation = bool(validation_data)

  is_sequence = isinstance(generator, Sequence)
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))
  if steps_per_epoch is None:
    if is_sequence:
      steps_per_epoch = len(generator)
    else:
      raise ValueError('`steps_per_epoch=None` is only valid for a'
                       ' generator based on the `keras.utils.Sequence`'
                       ' class. Please specify `steps_per_epoch` or use'
                       ' the `keras.utils.Sequence` class.')

  # python 2 has 'next', 3 has '__next__'
  # avoid any explicit version checks
  val_gen = (
      hasattr(validation_data, 'next') or
      hasattr(validation_data, '__next__') or
      isinstance(validation_data, Sequence))
  if (val_gen and not isinstance(validation_data, Sequence) and
      not validation_steps):
    raise ValueError('`validation_steps=None` is only valid for a'
                     ' generator based on the `keras.utils.Sequence`'
                     ' class. Please specify `validation_steps` or use'
                     ' the `keras.utils.Sequence` class.')

  # Prepare display labels.
  out_labels = model.metrics_names
  callback_metrics = out_labels + ['val_%s' % n for n in out_labels]

  # prepare callbacks
  model.history = cbks.History()
  callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
  if verbose:
    callbacks += [cbks.ProgbarLogger(count_mode='steps')]
  callbacks = cbks.CallbackList(callbacks)

  # it's possible to callback a different model than self:
  if hasattr(model, 'callback_model') and model.callback_model:
    callback_model = model.callback_model
  else:
    callback_model = model
  callbacks.set_model(callback_model)

  callback_params = {
      'epochs': epochs,
      'steps': steps_per_epoch,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  if do_validation:
    # need to create the test_function before start of the first epoch
    # because TensorBoard callback on_epoch_begin adds summary to the
    # list of fetches of the test_function
    model._make_test_function()
    # determine the number of validation batches given a generator
    if validation_steps:
      callback_params.update({'validation_steps': validation_steps})
    elif isinstance(validation_data, Sequence):
      callback_params.update({'validation_steps': len(validation_data)})
  callbacks.set_params(callback_params)

  enqueuer = None
  val_enqueuer = None

  try:
    if do_validation and not val_gen:
      # Prepare data for validation
      if len(validation_data) == 2:
        val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
        val_sample_weight = None
      elif len(validation_data) == 3:
        val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
      else:
        raise ValueError(
            '`validation_data` should be a tuple '
            '`(val_x, val_y, val_sample_weight)` '
            'or `(val_x, val_y)`. Found: ' + str(validation_data))
      val_x, val_y, val_sample_weights = model._standardize_user_data(
          val_x, val_y, val_sample_weight)
      val_data = val_x + val_y + val_sample_weights
      if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        val_data += [0.]
      for cbk in callbacks:
        cbk.validation_data = val_data

    if workers > 0:
      if is_sequence:
        enqueuer = OrderedEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle)
      else:
        enqueuer = GeneratorEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            wait_time=wait_time)
      enqueuer.start(workers=workers, max_queue_size=max_queue_size)
      output_generator = enqueuer.get()
    else:
      if is_sequence:
        output_generator = iter(generator)
      else:
        output_generator = generator

    callback_model.stop_training = False
    # validation_data must be set before on_train_begin() is called
    # so that TensorboardCallback can validate its input
    callbacks.on_train_begin()
    # Construct epoch logs.
    epoch_logs = {}
    while epoch < epochs:
      for m in model.stateful_metric_functions:
        m.reset_states()
      callbacks.on_epoch_begin(epoch)
      steps_done = 0
      batch_index = 0
      while steps_done < steps_per_epoch:
        generator_output = next(output_generator)

        if not hasattr(generator_output, '__len__'):
          raise ValueError('Output of generator should be '
                           'a tuple `(x, y, sample_weight)` '
                           'or `(x, y)`. Found: ' + str(generator_output))

        if len(generator_output) == 2:
          x, y = generator_output
          sample_weight = None
        elif len(generator_output) == 3:
          x, y, sample_weight = generator_output
        else:
          raise ValueError('Output of generator should be '
                           'a tuple `(x, y, sample_weight)` '
                           'or `(x, y)`. Found: ' + str(generator_output))
        # build batch logs
        batch_logs = {}
        if isinstance(x, list):
          batch_size = x[0].shape[0]
        elif isinstance(x, dict):
          batch_size = list(x.values())[0].shape[0]
        else:
          batch_size = x.shape[0]
        batch_logs['batch'] = batch_index
        batch_logs['size'] = batch_size
        callbacks.on_batch_begin(batch_index, batch_logs)

        outs = model.train_on_batch(
            x, y, sample_weight=sample_weight, class_weight=class_weight)

        if not isinstance(outs, list):
          outs = [outs]
        for l, o in zip(out_labels, outs):
          batch_logs[l] = o

        callbacks.on_batch_end(batch_index, batch_logs)

        batch_index += 1
        steps_done += 1

        # Epoch finished.
        if steps_done >= steps_per_epoch and do_validation:
          if val_gen:
            val_outs = evaluate_generator(
                model,
                validation_data,
                validation_steps,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=max_queue_size)
          else:
            # No need for try/except because
            # data has already been validated.
            val_outs = model.evaluate(
                val_x,
                val_y,
                batch_size=batch_size,
                sample_weight=val_sample_weights,
                verbose=0)
          if not isinstance(val_outs, list):
            val_outs = [val_outs]
          # Same labels assumed.
          for l, o in zip(out_labels, val_outs):
            epoch_logs['val_' + l] = o

        if callback_model.stop_training:
          break

      callbacks.on_epoch_end(epoch, epoch_logs)
      epoch += 1
      if callback_model.stop_training:
        break

  finally:
    try:
      if enqueuer is not None:
        enqueuer.stop()
    finally:
      if val_enqueuer is not None:
        val_enqueuer.stop()

  callbacks.on_train_end()
  return model.history
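
For reference, a hypothetical driver for the legacy loop above: wrapping the data in a `keras.utils.Sequence` lets `steps_per_epoch` and `validation_steps` be inferred from `len(...)`. The `fit_generator` call is left commented out since it needs a compiled v1-style model:

import numpy as np
from tensorflow.keras.utils import Sequence

class ToyBatches(Sequence):
    """Yields (x, y) batches; len() gives the steps per epoch."""
    def __init__(self, n_batches=8, batch_size=32):
        self.n_batches, self.batch_size = n_batches, batch_size

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        x = np.random.rand(self.batch_size, 4).astype("float32")
        y = x.sum(axis=1, keepdims=True)
        return x, y

# history = fit_generator(model, ToyBatches(), epochs=5,
#                         validation_data=ToyBatches(n_batches=2))
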
Exemple #14
0
    def train_v2(self,
                 idx_train,
                 idx_val=None,
                 epochs=200,
                 early_stopping=None,
                 verbose=False,
                 save_best=True,
                 weight_path=None,
                 as_model=False,
                 monitor='val_acc',
                 early_stop_metric='val_loss',
                 callbacks=None,
                 **kwargs):
        """
            Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details. (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            found during training or validation (depending on whether validation is used).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using customized `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """

        if not tf.__version__ >= '2.2.0':
            raise RuntimeError(
                'This method only works with tensorflow version >= 2.2.0.')

        # Check if model has been built
        if self.model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        model = self.model
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks,
                                                      add_history=True,
                                                      add_progbar=True,
                                                      verbose=verbose,
                                                      epochs=epochs)
        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 0))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path

            makedirs_from_path(weight_path)

            if not weight_path.endswith('.h5'):
                weight_path += '.h5'

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)
        callbacks.set_model(model)

        # leave it blank for the future
        allowed_kwargs = set()
        unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
        if unknown_kwargs:
            raise TypeError("Invalid keyword argument(s): %s" %
                            (unknown_kwargs, ))

        callbacks.on_train_begin()

        for epoch in range(epochs):
            callbacks.on_epoch_begin(epoch)

            callbacks.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_data)
            train_data.on_epoch_end()

            training_logs = {'loss': loss, 'acc': accuracy}
            callbacks.on_train_batch_end(0, training_logs)

            if validation:

                val_loss, val_accuracy = self.test_step(val_data)
                training_logs.update({
                    'val_loss': val_loss,
                    'val_acc': val_accuracy
                })
                val_data.on_epoch_end()

            callbacks.on_epoch_end(epoch, training_logs)

            if model.stop_training:
                break

        callbacks.on_train_end()

        if save_best:
            self.load(weight_path, as_model=as_model)
            remove_tf_weights(weight_path)

        return model.history
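
One caveat in the example above: `tf.__version__ >= '2.2.0'` compares strings lexicographically, so for instance `'2.10.0' >= '2.2.0'` evaluates to `False`. A sketch of a tuple-based check that avoids this (the helper name is illustrative, not from the original code):

import tensorflow as tf

def tf_version_tuple():
    # '2.4.1-rc0' -> (2, 4, 1); non-numeric suffixes are dropped
    nums = []
    for part in tf.__version__.split(".")[:3]:
        digits = "".join(ch for ch in part if ch.isdigit())
        nums.append(int(digits) if digits else 0)
    return tuple(nums)

if tf_version_tuple() < (2, 2, 0):
    raise RuntimeError("This method only works with tensorflow >= 2.2.0.")
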
Exemple #15
0
    def train(self,
              idx_train,
              idx_val=None,
              epochs=200,
              early_stopping=None,
              verbose=0,
              save_best=True,
              weight_path=None,
              as_model=False,
              monitor='val_acc',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2, 3, 4}
            'verbose=0': not verbose;
            'verbose=1': Progbar (one line, detailed);
            'verbose=2': Progbar (one line, omitted);
            'verbose=3': Progbar (multi line, detailed);
            'verbose=4': Progbar (multi line, omitted);
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            found during training or validation (depending on whether validation is used).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using custom `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        """
        raise_if_kwargs(kwargs)
        if not (isinstance(verbose, int) and 0 <= verbose <= 4):
            raise ValueError("'verbose=0': not verbose; "
                             "'verbose=1': Progbar (one line, detailed); "
                             "'verbose=2': Progbar (one line, omitted); "
                             "'verbose=3': Progbar (multi line, detailed); "
                             "'verbose=4': Progbar (multi line, omitted); "
                             f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = History()
        callbacks.append(history)

        if early_stopping:
            # NOTE: `kwargs` was already validated by `raise_if_kwargs` above,
            # so `es_verbose` can only take its default value here.
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 1))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path
            else:
                self.weight_path = weight_path

            makedirs_from_filename(weight_path)

            if not weight_path.endswith(POSTFIX):
                weight_path = weight_path + POSTFIX

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False
        callbacks.on_train_begin()

        if verbose:
            stateful_metrics = {"acc", 'loss', 'val_acc', 'val_loss', 'time'}
            if verbose <= 2:
                progbar = Progbar(target=epochs,
                                  verbose=verbose,
                                  stateful_metrics=stateful_metrics)
            print("Training...")

        begin_time = time.perf_counter()
        try:
            for epoch in range(epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      verbose=verbose - 2,
                                      stateful_metrics=stateful_metrics)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                loss, accuracy = self.train_step(train_data)

                training_logs = {'loss': loss, 'acc': accuracy}
                if validation:
                    val_loss, val_accuracy = self.test_step(val_data)
                    training_logs.update({
                        'val_loss': val_loss,
                        'val_acc': val_accuracy
                    })
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), training_logs)
                callbacks.on_epoch_end(epoch, training_logs)

                train_data.on_epoch_end()

                if verbose:
                    time_passed = time.perf_counter() - begin_time
                    training_logs.update({'time': time_passed})
                    if verbose > 2:
                        print(f"Epoch {epoch+1}/{epochs}")
                        progbar.update(len(train_data), training_logs.items())
                    else:
                        progbar.update(epoch + 1, training_logs.items())

                if model.stop_training:
                    break

        finally:
            callbacks.on_train_end()
            # restore the best weights and clean up, even if training is interrupted
            if save_best:
                self.load(weight_path, as_model=as_model)
                self.remove_weights()

        return history
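
The `stateful_metrics` argument above tells `tf.keras.utils.Progbar` to display those values as-is rather than averaging them over steps. A minimal sketch of the one-line mode (`verbose <= 2`) used in this example, with stand-in metric values:

import time
from tensorflow.keras.utils import Progbar

epochs = 5
stateful_metrics = {"loss", "acc", "time"}
progbar = Progbar(target=epochs, stateful_metrics=stateful_metrics)
begin = time.perf_counter()
for epoch in range(epochs):
    logs = {"loss": 1.0 / (epoch + 1),           # stand-in values
            "acc": 0.5 + 0.1 * epoch,
            "time": time.perf_counter() - begin}
    progbar.update(epoch + 1, logs.items())      # one tick per epoch
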
Exemple #16
0
def fit_loop(model,
             inputs,
             targets,
             sample_weights=None,
             batch_size=None,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             val_sample_weights=None,
             shuffle=True,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Abstract fit function for arrays of data.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      sample_weights: Optional list of sample weight arrays.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      val_sample_weights: Optional list of sample weight arrays.
      shuffle: Whether to shuffle the data at the beginning of each epoch
      callback_metrics: List of strings, the display names of the metrics
          passed to the callbacks. They should be the concatenation of the
          display names of the outputs of `f` and the display names of the
          outputs of `f_val`.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
    model._make_train_function()
    f = model.train_function

    sample_weights = sample_weights or []
    val_sample_weights = val_sample_weights or []
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = inputs + targets + sample_weights + [1]
        if val_inputs:
            val_ins = val_inputs + val_targets + val_sample_weights + [1]
    else:
        ins = inputs + targets + sample_weights
        if val_inputs:
            val_ins = val_inputs + val_targets + val_sample_weights
    if not val_inputs:
        val_ins = []

    do_validation = False
    if val_inputs:
        do_validation = True
        if (steps_per_epoch is None and verbose and inputs
                and hasattr(inputs[0], 'shape')
                and hasattr(val_inputs[0], 'shape')):
            print('Train on %d samples, validate on %d samples' %
                  (inputs[0].shape[0], val_inputs[0].shape[0]))
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')

    out_labels = model.metrics_names
    if do_validation:
        callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
        ]
    else:
        callback_metrics = copy.copy(out_labels)

    num_train_samples = training_utils.check_num_samples(
        ins, batch_size, steps_per_epoch, 'steps_per_epoch')
    if num_train_samples is not None:
        index_array = np.arange(num_train_samples)

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        if steps_per_epoch is not None:
            count_mode = 'steps'
        else:
            count_mode = 'samples'
        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # it's possible to callback a different model than self
    # (used by Sequential models)
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model

    callbacks.set_model(callback_model)

    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': num_train_samples,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False
    for cbk in callbacks:
        cbk.validation_data = val_ins

    # To prevent a slowdown, we find beforehand the arrays that need conversion.
    feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    indices_for_conversion_to_dense = []
    for i in range(len(feed)):
        if issparse is not None and issparse(
                ins[i]) and not K.is_sparse(feed[i]):
            indices_for_conversion_to_dense.append(i)

    for epoch in range(initial_epoch, epochs):
        # Reset stateful metrics
        for m in model.stateful_metric_functions:
            m.reset_states()
        # Update callbacks
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {}
                batch_logs['batch'] = step_index
                batch_logs['size'] = 1
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = f(ins)
                except errors.OutOfRangeError:
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break

            if do_validation:
                val_outs = test_loop(model,
                                     val_inputs,
                                     val_targets,
                                     sample_weights=val_sample_weights,
                                     batch_size=batch_size,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o
        else:
            if shuffle == 'batch':
                index_array = training_utils.batch_shuffle(
                    index_array, batch_size)
            elif shuffle:
                np.random.shuffle(index_array)

            batches = make_batches(num_train_samples, batch_size)

            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                try:
                    if isinstance(ins[-1], int):
                        # Do not slice the training phase flag.
                        ins_batch = slice_arrays(ins[:-1],
                                                 batch_ids) + [ins[-1]]
                    else:
                        ins_batch = slice_arrays(ins, batch_ids)
                except TypeError:
                    raise TypeError('TypeError while preparing batch. '
                                    'If using HDF5 input data, '
                                    'pass shuffle="batch".')
                batch_logs = {}
                batch_logs['batch'] = batch_index
                batch_logs['size'] = len(batch_ids)
                callbacks.on_batch_begin(batch_index, batch_logs)
                for i in indices_for_conversion_to_dense:
                    ins_batch[i] = ins_batch[i].toarray()

                outs = f(ins_batch)
                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)
                if callback_model.stop_training:
                    break

                if batch_index == len(batches) - 1:  # Last batch.
                    if do_validation:
                        val_outs = test_loop(model,
                                             val_inputs,
                                             val_targets,
                                             sample_weights=val_sample_weights,
                                             batch_size=batch_size,
                                             verbose=0)
                        if not isinstance(val_outs, list):
                            val_outs = [val_outs]
                        # Same labels assumed.
                        for l, o in zip(out_labels, val_outs):
                            epoch_logs['val_' + l] = o
        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()
    return model.history
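
The sample-wise branch of `fit_loop` shuffles an index array once per epoch and slices fixed-size batches from it. A condensed sketch of that pattern (`make_batches` is re-implemented here for illustration; in Keras it lives in the training utilities):

import numpy as np

def make_batches(size, batch_size):
    # e.g. [(0, 4), (4, 8), (8, 10)] for size=10, batch_size=4
    n = int(np.ceil(size / float(batch_size)))
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(n)]

num_train_samples, batch_size = 10, 4
index_array = np.arange(num_train_samples)
np.random.shuffle(index_array)               # in-place, once per epoch
for batch_start, batch_end in make_batches(num_train_samples, batch_size):
    batch_ids = index_array[batch_start:batch_end]
    # slice inputs/targets/sample_weights with `batch_ids` and train on them
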
Exemple #17
0
    def train(self,
              train_data,
              val_data=None,
              epochs=200,
              early_stopping=None,
              verbose=1,
              save_best=True,
              ckpt_path=None,
              as_model=False,
              monitor='val_accuracy',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `train_data` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        train_data: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of objects (or sequence) that will be used during training.
        val_data: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of objects (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2, 3, 4}
            'verbose=0': not verbose;
            'verbose=1': Progbar (one line, detailed);
            'verbose=2': Progbar (one line, omitted);
            'verbose=3': Progbar (multi line, detailed);
            'verbose=4': Progbar (multi line, omitted);
            (default :obj: 1)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            found during training or validation (depending on whether validation is used).
            (default :bool: `True`)
        ckpt_path: String or None
            The path of saved weights/model. (default to current path.)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using custom `layer` or `loss` and so on.
        monitor: String
            One of evaluation metrics, e.g., val_loss, val_accuracy, loss, accuracy, 
            it determines which metric will be used for `save_best`. 
            (default :obj: `val_accuracy`)
        early_stop_metric: String
            One of evaluation metrics, e.g., val_loss, val_accuracy, loss, accuracy, 
            it determines which metric will be used for early stopping. 
            (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        """
        raise_if_kwargs(kwargs)
        if not (isinstance(verbose, int) and 0 <= verbose <= 4):
            raise ValueError("'verbose=0': not verbose; "
                             "'verbose=1': Progbar (one line, detailed); "
                             "'verbose=2': Progbar (one line, omitted); "
                             "'verbose=3': Progbar (multi line, detailed); "
                             "'verbose=4': Progbar (multi line, omitted); "
                             f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        metrics_names = getattr(model, "metrics_names", None)
        # FIXME: This would return '[]' for tensorflow>=2.2.0
        # See <https://github.com/tensorflow/tensorflow/issues/37990>
        # metrics_names = ['loss', 'accuracy']
        if not metrics_names:
            raise RuntimeError("Please specify the attribute 'metrics_names' for the model.")
        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        self.train_data = train_data

        validation = val_data is not None

        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            self.val_data = val_data
            metrics_names = metrics_names + ["val_" + metric for metric in metrics_names]

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = History()
        callbacks.append(history)

        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 1))
            callbacks.append(es_callback)

        if save_best:
            if not ckpt_path:
                ckpt_path = self.ckpt_path
            else:
                self.ckpt_path = ckpt_path

            makedirs_from_filepath(ckpt_path)

            if not ckpt_path.endswith(gg.file_ext()):
                ckpt_path = ckpt_path + gg.file_ext()

            if monitor not in metrics_names:
                warnings.warn(f"'{monitor}' is not among the metrics names; defaulting to '{metrics_names[-1]}'.",
                              UserWarning)
                monitor = metrics_names[-1]

            mc_callback = ModelCheckpoint(ckpt_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False

        if verbose:
            if verbose <= 2:
                progbar = Progbar(target=epochs,
                                  width=20,
                                  verbose=verbose)
            print("Training...")

        logs = BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      width=20,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()

                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if save_best:
                self.load(ckpt_path, as_model=as_model)
        finally:
            # clean up the checkpoint file even if training is interrupted
            if save_best:
                self.remove_weights()

        return history
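
In these manual loops, early stopping works because `EarlyStopping.on_epoch_end` sets `model.stop_training = True` once the monitored metric stops improving, and the loop checks that flag. A runnable sketch with synthetic `val_loss` values and a patience of 2 (the model here is a stand-in):

import tensorflow as tf
from tensorflow.python.keras import callbacks as callbacks_module

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer="sgd", loss="mse")
model.stop_training = False

es = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)
callbacks = callbacks_module.CallbackList([es])
callbacks.set_model(model)

callbacks.on_train_begin()                    # resets the patience counter
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.97, 0.99]):
    callbacks.on_epoch_end(epoch, {"val_loss": val_loss})
    if model.stop_training:
        print(f"Early Stopping at Epoch {epoch}")   # triggers at epoch 3
        break
callbacks.on_train_end()
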
Exemple #18
0
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            auto_switch=True,
            retry_fit=True,
            absorb=True,
            train_after_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            revert_after_fit=False):
        """
        Custom fit function for the context model

        auto_switch:        Enable/disable autonomous context switching
        train_after_switch:
        retry_fit:          Locate the next fitting context by re-performing fit.
        absorb:             Reset the switch sequence counter upon successful training.
                            This is mainly used to maintain switch sequencing for temporally-extended tasks
        revert_after_fit    This is a debug parameter to revert weights after performing a fit. This is used
                            to calculate the context deltas without incorrectly learning while auto switching
                            is disabled
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            self.initialize_fit()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched_during_epoch = False  # Indicate if the model has attempted at least one switch during this epoch
                switched = True  # Indicate if the model switched on the most recent fit iteration
                weights = backend.batch_get_value(self.trainable_variables)
                # Perform a 'fit call'. Assuming retry_fit, this call is re-attempted after each switch until a context fits
                while switched and (retry_fit or not switched_during_epoch):
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)

                    # Perform a fit call
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, auto_switch, absorb, retry_fit, verbose)
                        switched_during_epoch |= switched

                        # If a switch occurred, we need to restore the weights
                        if switched or (switched_during_epoch
                                        and not train_after_switch
                                        ) or revert_after_fit:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
Exemple #19
0
def train_model(name,
                g_train,
                d_train,
                sampler,
                generator,
                samples_per_epoch,
                epochs,
                z_dim=100,
                verbose=1,
                callbacks=None,
                saver=None):
    """
    Main training loop.

    Modified version of fit_generator.
    """
    self = {}  # vestigial stand-in; this standalone loop has no model instance
    epoch = 0
    counter = 0
    out_labels = ['g_loss', 'd_loss', 'd_loss_fake', 'd_loss_legit',
                  'time']  # self.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [history]
    if verbose:
        callbacks += [cbks.ProgbarLogger()]
    callbacks = cbks.CallbackList(callbacks)

    callback_params = {
        'epochs': epochs,
        'samples': samples_per_epoch,
        'verbose': verbose,
        'metrics': callback_metrics,
    }
    callbacks.set_params(callback_params)

    callbacks.on_train_begin()

    while epoch < epochs:
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}  # built once per epoch, consumed by on_epoch_end
        samples_seen = 0
        batch_index = 0
        while samples_seen < samples_per_epoch:
            z, x = next(generator)
            # build batch logs
            batch_logs = {}
            if type(x) is list:
                batch_size = len(x[0])
            elif type(x) is dict:
                batch_size = len(list(x.values())[0])
            else:
                batch_size = len(x)
            batch_logs['batch'] = batch_index
            batch_logs['size'] = batch_size
            callbacks.on_batch_begin(batch_index, batch_logs)

            t1 = time.time()
            d_losses = d_train(x, z, counter)
            z, x = next(generator)
            g_loss, samples, xs = g_train(x, z, counter)
            outs = (g_loss, ) + d_losses + (time.time() - t1, )
            counter += 1

            # save samples
            if batch_index % 100 == 0:
                join_image = np.zeros_like(
                    np.concatenate([samples[:64], xs[:64]], axis=0))
                for j, (i1, i2) in enumerate(zip(samples[:64], xs[:64])):
                    join_image[j * 2] = i1
                    join_image[j * 2 + 1] = i2
                save_images(
                    join_image, [8 * 2, 8],
                    './outputs/samples_%s/train_%s_%s.png' %
                    (name, epoch, batch_index))

                samples, xs = sampler(z, x)
                join_image = np.zeros_like(
                    np.concatenate([samples[:64], xs[:64]], axis=0))
                for j, (i1, i2) in enumerate(zip(samples[:64], xs[:64])):
                    join_image[j * 2] = i1
                    join_image[j * 2 + 1] = i2
                save_images(
                    join_image, [8 * 2, 8],
                    './outputs/samples_%s/test_%s_%s.png' %
                    (name, epoch, batch_index))

            for l, o in zip(out_labels, outs):
                batch_logs[l] = o

            callbacks.on_batch_end(batch_index, batch_logs)

            batch_index += 1
            samples_seen += batch_size

        if saver is not None:
            saver(epoch)

        callbacks.on_epoch_end(epoch, epoch_logs)
        epoch += 1

    callbacks.on_train_end()
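
The sample-saving blocks above interleave generated and real images pair-wise before tiling them into one grid. The same interleaving can be written with strided assignment (shapes here are placeholders; `save_images` belongs to the original codebase):

import numpy as np

samples = np.random.rand(64, 32, 32, 3)   # generated images (placeholder)
xs = np.random.rand(64, 32, 32, 3)        # real images (placeholder)

join_image = np.zeros_like(np.concatenate([samples, xs], axis=0))
join_image[0::2] = samples                # even slots: generated
join_image[1::2] = xs                     # odd slots: real
# save_images(join_image, [8 * 2, 8], './outputs/samples_demo.png')
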