def fit_dataset(self, dataset, steps_per_epoch=None, batch_size=32,
                epochs=1, verbose=1, callbacks=None, on_sample=None,
                on_scores=None):
    """Train the model on the given dataset for a given number of epochs.

    Arguments
    ---------
        dataset: Instance of `BaseDataset` that provides the data
                 to train on.
        steps_per_epoch: int or None, number of gradient updates before
                         considering an epoch has passed. If None it is set
                         to be `len(dataset.train_data) // batch_size`.
        batch_size: int, number of samples per gradient update
        epochs: int, number of times to iterate `steps_per_epoch` times
        verbose: {0, >0}, whether to employ the progress bar Keras
                 callback or not
        callbacks: list of Keras callbacks to be called during training
        on_sample: callable that accepts the sampler, idxs, w, scores
        on_scores: callable that accepts the sampler and scores

    Returns
    -------
        The `History` callback that recorded this run (it is also stored
        on `self.history`).

    Raises
    ------
        ValueError: if the training set has fewer than `batch_size` samples.
        RuntimeError: if measuring `dataset.train_data` fails for any reason
                      other than the dataset reporting it has "no size".
    """
    # A RuntimeError mentioning "no size" means the training set cannot
    # report its length; only that case may skip the size check. Any other
    # RuntimeError is propagated (an `assert` here would vanish under -O).
    try:
        n_train = len(dataset.train_data)
    except RuntimeError as e:
        if "no size" not in str(e):
            raise
    else:
        if n_train < batch_size:
            raise ValueError(("The model cannot be trained with "
                              "batch_size > training set"))

    # Set steps_per_epoch properly
    if steps_per_epoch is None:
        steps_per_epoch = len(dataset.train_data) // batch_size

    # Create the callbacks list; the History instance is returned at the end.
    # A new list is built, so the caller's `callbacks` list is not mutated.
    self.history = History()
    callbacks = [BaseLogger()] + (callbacks or []) + [self.history]
    if verbose > 0:
        callbacks += [ProgbarLogger(count_mode="steps")]
    callbacks = CallbackList(callbacks)
    callbacks.set_model(self.original_model)
    has_validation = len(dataset.test_data) > 0
    callbacks.set_params({
        "epochs": epochs,
        "steps": steps_per_epoch,
        "verbose": verbose,
        "do_validation": has_validation,
        "metrics": self.metrics_names + [
            "val_" + n for n in self.metrics_names
        ],
    })

    # Create the sampler
    sampler = self.sampler(dataset, batch_size, steps_per_epoch, epochs)

    # Start the training loop; callbacks may set stop_training to end early.
    epoch = 0
    self.original_model.stop_training = False
    callbacks.on_train_begin()
    while epoch < epochs:
        callbacks.on_epoch_begin(epoch)
        for step in range(steps_per_epoch):
            batch_logs = {"batch": step, "size": batch_size}
            callbacks.on_batch_begin(step, batch_logs)

            # Importance sampling is done here
            idxs, (x, y), w = sampler.sample(batch_size)
            # Train on the sampled data
            loss, metrics, scores = self.model.train_batch(x, y, w)
            # Update the sampler
            sampler.update(idxs, scores)

            # Each loss/metric value is reduced with .mean() before logging.
            for name, value in zip(self.metrics_names, [loss] + metrics):
                batch_logs[name] = value.mean()
            callbacks.on_batch_end(step, batch_logs)

            if on_scores is not None and hasattr(self, "_latest_scores"):
                on_scores(sampler, self._latest_scores)

            if on_sample is not None:
                event = self._latest_sample_event
                on_sample(sampler, event["idxs"], event["w"],
                          event["predicted_scores"])

            if self.original_model.stop_training:
                break

        # Evaluate now that an epoch passed
        epoch_logs = {}
        if has_validation:
            val = self.model.evaluate(*dataset.test_data[:],
                                      batch_size=batch_size)
            epoch_logs = {
                "val_" + name: value
                for name, value in zip(self.metrics_names, val)
            }
        callbacks.on_epoch_end(epoch, epoch_logs)
        if self.original_model.stop_training:
            break
        epoch += 1
    callbacks.on_train_end()

    return self.history
def fit_generator(self, generator, n_steps_per_epoch, n_epochs=1,
                  validation_data=None, n_validation_steps=None):
    """Train the network on batches of data generated from `generator`

    :param generator: a generator yielding batches indefinitely, where each
     batch is a tuple of (inputs, targets)
    :type generator: generator
    :param n_steps_per_epoch: number of batches to train on in one epoch
    :type n_steps_per_epoch: int
    :param n_epochs: number of epochs to train the model
    :type n_epochs: int
    :param validation_data: generator yielding batches to evaluate the loss
     on at the end of each epoch, where each batch is a tuple of (inputs,
     targets)
    :type validation_data: generator
    :param n_validation_steps: number of batches to evaluate on from
     `validation_data`
    :type n_validation_steps: int
    :raises RuntimeError: if only one of `validation_data` and
     `n_validation_steps` are passed in
    """
    default_callbacks = self._default_callbacks()
    callbacks = CallbackList(default_callbacks)

    self._assert_compiled()

    # Validation needs both a data source and a step count, or neither.
    if (validation_data is None) != (n_validation_steps is None):
        msg = ('`validation_data` and `n_validation_steps` must both be '
               'passed, or neither.')
        raise RuntimeError(msg)

    if self.device:
        self.network.to(self.device)

    callbacks.set_params({
        'epochs': n_epochs,
        'metrics': ['loss', 'val_loss'],
        'steps': n_steps_per_epoch,
        'verbose': True
    })
    callbacks.set_model(self)

    # Callbacks set `self.stop_training` to end training early. Reset it so
    # a previous early stop does not turn this call into a silent no-op.
    self.stop_training = False
    callbacks.on_train_begin()
    for idx_epoch in range(n_epochs):
        if self.stop_training:
            break

        epoch_logs = {}
        callbacks.on_epoch_begin(idx_epoch)

        for idx_batch in range(n_steps_per_epoch):
            batch_logs = {'batch': idx_batch, 'size': 1}
            callbacks.on_batch_begin(idx_batch, batch_logs)

            inputs, targets = next(generator)
            loss = self.train_on_batch(inputs, targets)

            batch_logs['loss'] = loss
            callbacks.on_batch_end(idx_batch, batch_logs)

            if self.stop_training:
                break

        # Explicit None check (matching the argument validation above);
        # truthiness is ambiguous for array-like validation inputs.
        if validation_data is not None:
            val_loss = self.evaluate_generator(validation_data,
                                               n_validation_steps)
            epoch_logs['val_loss'] = val_loss
        callbacks.on_epoch_end(idx_epoch, epoch_logs)
    callbacks.on_train_end()