Example #1
    def fit_generator(self, generator, epochs=1,
                      validation_data=None,
                      callbacks=None,
                      verbose=True):
        method = self._model.optimizer.method
        x0 = self._collect_weights()
        history = History()
        _callbacks = [BaseLogger(stateful_metrics=self._model.metrics_names)]
        _callbacks += (callbacks or []) + [history]
        callback_list = CallbackList(_callbacks)
        callback_list.set_model(self._model)
        callback_list.set_params({
            'epochs': epochs,
            'verbose': False,
            'metrics': list(self._model.metrics_names),
        })
        state = {
            'epoch': 0,
            'verbose': verbose,
            'callbacks': callback_list,
            'in_epoch': False,
            'epoch_logs': {},
        }
        min_options = {
            'maxiter': epochs,
            'maxfun': epochs * 10,
            'ftol': 1e-10,
            'gtol': 1e-10,
            'eps': 1e-8,
        }

        val_generator = None
        if validation_data is not None:
            if isinstance(validation_data, keras.utils.Sequence):
                val_generator = validation_data
            elif isinstance(validation_data, tuple) and len(validation_data) == 2:
                val_generator = GeneratorWrapper(*validation_data)

        def on_iteration_end(xk):
            cb = state['callbacks']
            if val_generator is not None:
                self._validate(xk, val_generator, state)
            cb.on_epoch_end(state['epoch'], state['epoch_logs'])
            if state['verbose']:
                epoch_logs = state['epoch_logs']
                print('epoch:', state['epoch'],
                      ', '.join('{0}: {1:.3e}'.format(k, v)
                                for k, v in epoch_logs.items()))
            state['epoch'] += 1
            state['in_epoch'] = False
            state['epoch_logs'] = {}

        callback_list.on_train_begin()
        result = minimize(
            self._fun_generator, x0, method=method, jac=True, options=min_options,
            callback=on_iteration_end, args=(generator, state))
        self._update_weights(result['x'])
        callback_list.on_train_end()
        return history
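The pattern in this example, driving a Keras `CallbackList` from the iteration callback of `scipy.optimize.minimize`, can be exercised on its own. Below is a minimal, self-contained sketch on a toy quadratic objective; none of the names come from the example's library:

import types
import numpy as np
from scipy.optimize import minimize
from keras.callbacks import CallbackList, History

def fun(x):
    # Objective value and gradient together, matching jac=True below.
    return float(np.dot(x, x)), 2.0 * x

history = History()
callback_list = CallbackList([history])
# The callbacks only need some object to hang attributes on here.
callback_list.set_model(types.SimpleNamespace())
callback_list.set_params({'epochs': 50, 'verbose': False, 'metrics': ['loss']})

state = {'epoch': 0}

def on_iteration_end(xk):
    # Report each optimizer iteration to the callbacks as one "epoch".
    callback_list.on_epoch_end(state['epoch'], {'loss': float(np.dot(xk, xk))})
    state['epoch'] += 1

callback_list.on_train_begin()
result = minimize(fun, np.ones(3), method='L-BFGS-B', jac=True,
                  options={'maxiter': 50}, callback=on_iteration_end)
callback_list.on_train_end()
print(len(history.history['loss']), result.x)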
Example #2
def prepare_callbacks(g,
                      d,
                      callbacks,
                      n_epochs=1,
                      n_batches=1,
                      include_d_metrics=False):
    # all the callback stuff is from https://github.com/keras-team/keras/blob/master/keras/engine/training_generator.py
    # NOTE: see if saving the weights of d_on_g is enough
    # Prepare display labels.
    out_labels = g.metrics_names
    out_labels = [_replace_label_first_underscore(l) for l in out_labels]
    # we only want to validate on the output of g
    val_out_labels = ['val_' + n for n in out_labels if g.name in n]
    callback_metrics = out_labels + val_out_labels
    if include_d_metrics:
        d_metrics_names = d.metrics_names
        d_metrics_fake = ['d_training/' + l + '_fake' for l in d_metrics_names]
        d_metrics_real = ['d_training/' + l + '_real' for l in d_metrics_names]
        d_metrics_names = d_metrics_fake + d_metrics_real
        callback_metrics += d_metrics_names
    # prepare callbacks
    g.history = cbks.History()
    _callbacks = [cbks.BaseLogger(stateful_metrics=g.metrics_names[1:])]
    _callbacks += (callbacks or []) + [g.history]
    callbacks = CallbackList(_callbacks)

    # it's possible to call back on a different model than `g` itself:
    callback_model = g._get_callback_model()

    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': n_epochs,
        'steps': n_batches,
        'verbose': 0,
        # 'do_validation': do_validation, to set when using validation data
        'metrics': callback_metrics,
    })
    if not include_d_metrics:
        d_metrics_fake, d_metrics_real = None, None
    return callbacks, out_labels, val_out_labels, d_metrics_fake, d_metrics_real
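A sketch of how `prepare_callbacks` might be consumed, assuming it sits in an importable module together with its own imports (`cbks` aliasing `keras.callbacks`, `CallbackList` from the same place) and a multi-backend Keras 2.2.x install (it calls the private `Model._get_callback_model`). `_replace_label_first_underscore` is not shown above; the stand-in below, which swaps the first underscore for a slash, is a guess and must live in the same module for the call to resolve it:

import numpy as np
import keras
import keras.callbacks as cbks
from keras.callbacks import CallbackList

def _replace_label_first_underscore(label):
    # Stand-in for the helper used above (assumed: first '_' becomes '/',
    # e.g. for TensorBoard-style metric grouping).
    return label.replace('_', '/', 1)

# Two tiny compiled models standing in for a generator `g` and discriminator `d`.
g = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))], name='g')
g.compile(optimizer='sgd', loss='mse')
d = keras.Sequential([keras.layers.Dense(1, input_shape=(1,), activation='sigmoid')], name='d')
d.compile(optimizer='sgd', loss='binary_crossentropy')

n_epochs, n_batches = 2, 4
callbacks, out_labels, _, _, _ = prepare_callbacks(g, d, None, n_epochs, n_batches)

x, y = np.random.rand(8, 2), np.random.rand(8, 1)
callbacks.on_train_begin()
for epoch in range(n_epochs):
    callbacks.on_epoch_begin(epoch)
    for step in range(n_batches):
        batch_logs = {'batch': step, 'size': len(x)}
        callbacks.on_batch_begin(step, batch_logs)
        outs = g.train_on_batch(x, y)
        outs = outs if isinstance(outs, list) else [outs]
        for label, out in zip(out_labels, outs):
            batch_logs[label] = out
        callbacks.on_batch_end(step, batch_logs)
    # BaseLogger fills the empty dict with averaged metrics before History sees it.
    callbacks.on_epoch_end(epoch, {})
callbacks.on_train_end()
print(g.history.history)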
Example #3
    def fit_dataset(self,
                    dataset,
                    steps_per_epoch=None,
                    batch_size=32,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    on_sample=None,
                    on_scores=None):
        """Train the model on the given dataset for a given number of epochs.

        Arguments
        ---------
            dataset: Instance of `BaseDataset` that provides the data
                     to train on.
            steps_per_epoch: int or None, number of gradient updates before
                             considering an epoch has passed. If None it is set
                             to be `len(dataset.train_data) / batch_size`.
            batch_size: int, number of samples per gradient update
            epochs: int, number of times to iterate `steps_per_epoch` times
            verbose: {0, >0}, whether to use the Keras progress bar
                     callback or not
            callbacks: list of Keras callbacks to be called during training
            on_sample: callable that accepts the sampler, idxs, w, scores
            on_scores: callable that accepts the sampler and scores
        """
        # `len(dataset.train_data)` may raise RuntimeError for datasets
        # that do not expose a size; in that case the check is skipped.
        try:
            if len(dataset.train_data) < batch_size:
                raise ValueError("The model cannot be trained with "
                                 "batch_size > training set")
        except RuntimeError as e:
            assert "no size" in str(e)

        # Set steps_per_epoch properly
        if steps_per_epoch is None:
            steps_per_epoch = len(dataset.train_data) // batch_size

        # Create the callbacks list
        self.history = History()
        callbacks = [BaseLogger()] + (callbacks or []) + [self.history]
        if verbose > 0:
            callbacks += [ProgbarLogger(count_mode="steps")]
        callbacks = CallbackList(callbacks)
        callbacks.set_model(self.original_model)
        callbacks.set_params({
            "epochs": epochs,
            "steps": steps_per_epoch,
            "verbose": verbose,
            "do_validation": len(dataset.test_data) > 0,
            "metrics": self.metrics_names + ["val_" + n for n in self.metrics_names],
        })

        # Create the sampler
        sampler = self.sampler(dataset, batch_size, steps_per_epoch, epochs)

        # Start the training loop
        epoch = 0
        self.original_model.stop_training = False
        callbacks.on_train_begin()
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            for step in range(steps_per_epoch):
                batch_logs = {"batch": step, "size": batch_size}
                callbacks.on_batch_begin(step, batch_logs)

                # Importance sampling is done here
                idxs, (x, y), w = sampler.sample(batch_size)
                # Train on the sampled data
                loss, metrics, scores = self.model.train_batch(x, y, w)
                # Update the sampler
                sampler.update(idxs, scores)

                values = (v.mean() for v in [loss] + metrics)
                for l, o in zip(self.metrics_names, values):
                    batch_logs[l] = o
                callbacks.on_batch_end(step, batch_logs)

                if on_scores is not None and hasattr(self, "_latest_scores"):
                    on_scores(sampler, self._latest_scores)

                if on_sample is not None:
                    on_sample(sampler, self._latest_sample_event["idxs"],
                              self._latest_sample_event["w"],
                              self._latest_sample_event["predicted_scores"])

                if self.original_model.stop_training:
                    break

            # Evaluate now that an epoch passed
            epoch_logs = {}
            if len(dataset.test_data) > 0:
                val = self.model.evaluate(*dataset.test_data[:],
                                          batch_size=batch_size)
                epoch_logs = {
                    "val_" + l: o
                    for l, o in zip(self.metrics_names, val)
                }
            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.original_model.stop_training:
                break
            epoch += 1
        callbacks.on_train_end()

        return self.history
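Stripped of the importance-sampling machinery, the callback wiring in this example reduces to a generic recipe: drive `BaseLogger`, `ProgbarLogger` and `History` by hand around `train_on_batch` to get `fit()`-style logging and a progress bar from a custom loop. A self-contained sketch with plain uniform batches (no sampler; all names are local to the sketch):

import numpy as np
import keras
from keras.callbacks import BaseLogger, CallbackList, History, ProgbarLogger

model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(256, 4)
y = np.random.rand(256, 1)
batch_size, steps_per_epoch, epochs = 32, 8, 2

history = History()
callbacks = CallbackList([BaseLogger(), ProgbarLogger(count_mode='steps'), history])
callbacks.set_model(model)
callbacks.set_params({
    'epochs': epochs,
    'steps': steps_per_epoch,
    'verbose': 1,
    'do_validation': False,
    'metrics': model.metrics_names,
})

callbacks.on_train_begin()
for epoch in range(epochs):
    callbacks.on_epoch_begin(epoch)
    for step in range(steps_per_epoch):
        batch_logs = {'batch': step, 'size': batch_size}
        callbacks.on_batch_begin(step, batch_logs)
        sl = slice(step * batch_size, (step + 1) * batch_size)
        batch_logs['loss'] = model.train_on_batch(x[sl], y[sl])
        callbacks.on_batch_end(step, batch_logs)
    callbacks.on_epoch_end(epoch, {})
callbacks.on_train_end()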
Example #4
    def fit_generator(self,
                      generator,
                      n_steps_per_epoch,
                      n_epochs=1,
                      validation_data=None,
                      n_validation_steps=None):
        """Train the network on batches of data generated from `generator`

        :param generator: a generator yielding batches indefinitely, where each
         batch is a tuple of (inputs, targets)
        :type generator: generator
        :param n_steps_per_epoch: number of batches to train on in one epoch
        :type n_steps_per_epoch: int
        :param n_epochs: number of epochs to train the model
        :type n_epochs: int
        :param validation_data: generator yielding batches to evaluate the loss
         on at the end of each epoch, where each batch is a tuple of (inputs,
         targets)
        :type validation_data: generator
        :param n_validation_steps: number of batches to evaluate on from
         `validation_data`
        :type n_validation_steps: int
        :raises RuntimeError: if only one of `validation_data` and
         `n_validation_steps` is passed in
        """

        default_callbacks = self._default_callbacks()
        callbacks = CallbackList(default_callbacks)

        self._assert_compiled()

        invalid_inputs = (
            (validation_data is not None and n_validation_steps is None)
            or (n_validation_steps is not None and validation_data is None))
        if invalid_inputs:
            msg = ('`validation_data` and `n_validation_steps` must both be '
                   'passed, or neither.')
            raise RuntimeError(msg)

        if self.device:
            self.network.to(self.device)

        callbacks.set_params({
            'epochs': n_epochs,
            'metrics': ['loss', 'val_loss'],
            'steps': n_steps_per_epoch,
            'verbose': True
        })
        callbacks.set_model(self)

        callbacks.on_train_begin()
        for idx_epoch in range(n_epochs):
            if self.stop_training:
                break

            epoch_logs = {}
            callbacks.on_epoch_begin(idx_epoch)

            for idx_batch in range(n_steps_per_epoch):
                batch_logs = {'batch': idx_batch, 'size': 1}
                callbacks.on_batch_begin(idx_batch, batch_logs)

                inputs, targets = next(generator)
                loss = self.train_on_batch(inputs, targets)

                batch_logs['loss'] = loss
                callbacks.on_batch_end(idx_batch, batch_logs)

                if self.stop_training:
                    break

            if validation_data:
                val_loss = self.evaluate_generator(validation_data,
                                                   n_validation_steps)
                epoch_logs['val_loss'] = val_loss
            callbacks.on_epoch_end(idx_epoch, epoch_logs)
        callbacks.on_train_end()
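One detail worth noting in this example is `callbacks.set_model(self)`: the "model" handed to the Keras callbacks is the trainer itself, not a Keras model. Any callback that only touches attributes the trainer actually provides then works unchanged. A small sketch with `EarlyStopping`, which only needs a `stop_training` attribute on whatever `set_model` receives (the trainer class and loss sequence here are made up for illustration):

from keras.callbacks import CallbackList, EarlyStopping

class TinyTrainer:
    def __init__(self):
        self.stop_training = False

trainer = TinyTrainer()
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
callbacks = CallbackList([early_stopping])
callbacks.set_model(trainer)
callbacks.set_params({'epochs': 10, 'metrics': ['loss', 'val_loss'],
                      'steps': 1, 'verbose': True})

callbacks.on_train_begin()
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.95, 1.1]):
    callbacks.on_epoch_begin(epoch)
    # EarlyStopping reads 'val_loss' from the logs and, after `patience`
    # epochs without improvement, sets trainer.stop_training = True.
    callbacks.on_epoch_end(epoch, {'val_loss': val_loss})
    if trainer.stop_training:
        print('early stop at epoch', epoch)
        break
callbacks.on_train_end()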