def evaluate_generator(self, generator, verbose=1, steps=None, pass_state=False):
    """ Run an evaluation pass over a data generator and collect the resulting metrics

    :param generator: Source of evaluation batches (usually a pytorch DataLoader)
    :type generator: DataLoader
    :param verbose: When exactly 1 a tqdm progress bar is attached, any other value shows no progress
    :type verbose: int
    :param steps: How many evaluation mini-batches to run
    :type steps: int
    :param pass_state: When True the state dictionary is also handed to the torch model forward method, when False only the input data is
    :type pass_state: bool
    :return: The metrics dictionary left in the state under torchbearer.METRICS
    :rtype: dict[str,any]
    """
    # Seed a single-epoch state (epoch 0 of 1) with STOP_TRAINING cleared.
    eval_state = {
        torchbearer.EPOCH: 0,
        torchbearer.MAX_EPOCHS: 1,
        torchbearer.STOP_TRAINING: False,
        torchbearer.VALIDATION_GENERATOR: generator,
    }
    # Values from the main state override the defaults above on key clashes.
    eval_state.update(self.main_state)

    printers = [Tqdm('e')] if verbose == 1 else []
    self._test_loop(eval_state, CallbackList(printers), pass_state,
                    self._load_batch_standard, steps)
    return eval_state[torchbearer.METRICS]
def predict_generator(self, generator, verbose=1, steps=None, pass_state=False):
    """Run a prediction pass over a data generator and return the aggregated predictions

    :param generator: Source of prediction batches (usually a pytorch DataLoader)
    :type generator: DataLoader
    :param verbose: When exactly 1 a tqdm progress bar is attached, any other value shows no progress
    :type verbose: int
    :param steps: How many prediction mini-batches to run
    :type steps: int
    :param pass_state: When True the state dictionary is also handed to the torch model forward method, when False only the input data is
    :type pass_state: bool
    :return: The value left in the state under torchbearer.FINAL_PREDICTIONS
    :rtype: torch.Tensor
    """
    # Seed a single-epoch state (epoch 0 of 1) with STOP_TRAINING cleared.
    predict_state = {
        torchbearer.EPOCH: 0,
        torchbearer.MAX_EPOCHS: 1,
        torchbearer.STOP_TRAINING: False,
        torchbearer.VALIDATION_GENERATOR: generator,
    }
    # Values from the main state override the defaults above on key clashes.
    predict_state.update(self.main_state)

    # AggregatePredictions always comes first; the progress bar is optional and
    # is constructed after it (left operand of + is evaluated first).
    handlers = [AggregatePredictions()] + ([Tqdm('p')] if verbose == 1 else [])
    self._test_loop(predict_state, CallbackList(handlers), pass_state,
                    self._load_batch_predict, steps)
    return predict_state[torchbearer.FINAL_PREDICTIONS]
def _add_printer(callbacks, verbose, validation_label_letter='v'): """Static method used to add the printer callback to the given list for the given verbose level :param callbacks: The list to add to :type callbacks: list :param verbose: 2, 1 or 0, Most -> Least verbose :type verbose: int :param validation_label_letter: Pass to Tqdm :type validation_label_letter: str :return: The updated list :rtype: list """ if verbose >= 2: return [Tqdm(validation_label_letter=validation_label_letter)] + callbacks elif verbose >= 1: return [Tqdm(validation_label_letter=validation_label_letter, on_epoch=True)] + callbacks else: return callbacks
def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callbacks=None,
                  validation_generator=None, validation_steps=None, initial_epoch=0,
                  pass_state=False):
    """ Perform fitting of a model to given data generator

    :param generator: The training data generator (usually a pytorch DataLoader)
    :type generator: DataLoader
    :param train_steps: The number of training mini-batches to run per epoch. None, or a value larger than len(generator), means len(generator); non-int values are converted with int() and a warning is issued
    :type train_steps: int
    :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once). Floats are truncated to int with a warning; any other non-int type becomes 0 with a warning
    :type epochs: int
    :param verbose: If 1 use tqdm progress frontend, else display no training progress
    :type verbose: int
    :param callbacks: The torchbearer callbacks to be called during training and validation. None (the default) means no extra callbacks. The given sequence is copied, so the caller's list is never handed on or modified here
    :type callbacks: list
    :param validation_generator: The validation data generator (usually a pytorch DataLoader)
    :type validation_generator: DataLoader
    :param validation_steps: The number of validation mini-batches to run per epoch. Defaults to len(validation_generator) when a validation generator is given
    :type validation_steps: int
    :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs
    :type initial_epoch: int
    :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
    :type pass_state: bool
    :return: The final state context dictionary
    :rtype: dict[str,any]
    """
    # Build a fresh list on every call. The former ``callbacks=[]`` default was a
    # single list object shared by every call that omitted the argument.
    callbacks = [] if callbacks is None else list(callbacks)
    if verbose == 1:
        callbacks = [Tqdm()] + callbacks
    _callbacks = CallbackList(callbacks)

    # Get train and validation steps
    if validation_steps is None and validation_generator is not None:
        validation_steps = len(validation_generator)
    # Never ask for more training batches than one pass over the generator yields.
    if train_steps is None or train_steps > len(generator):
        train_steps = len(generator)
    if not isinstance(train_steps, int):
        train_steps = int(train_steps)
        warnings.warn(
            "Number of training steps is not an int, converting to int")

    if not isinstance(epochs, int):
        if isinstance(epochs, float):
            epochs = int(epochs)
            warnings.warn("Number of epochs is a float, converting to int")
        else:
            warnings.warn(
                "Number of epochs is neither float nor int, setting to 0")
            epochs = 0

    # Init state
    state = {
        torchbearer.MAX_EPOCHS: epochs,
        torchbearer.TRAIN_STEPS: train_steps,
        torchbearer.BATCH: 0,
        torchbearer.GENERATOR: generator,
        torchbearer.STOP_TRAINING: False
    }
    # Values from the main state override the defaults above on key clashes.
    state.update(self.main_state)

    _callbacks.on_start(state)

    # NOTE(review): if initial_epoch >= epochs this loop body never runs, so
    # state[torchbearer.EPOCH] is never assigned before on_end (unless main_state
    # supplies it) - confirm callbacks tolerate that.
    for state[torchbearer.EPOCH] in range(initial_epoch, epochs):
        _callbacks.on_start_epoch(state)

        state[torchbearer.TRAIN_ITERATOR] = iter(
            state[torchbearer.GENERATOR])
        self.train()

        _callbacks.on_start_training(state)
        state[torchbearer.METRIC_LIST].reset(state)
        state[torchbearer.METRICS] = {}

        for state[torchbearer.BATCH] in range(
                0, state[torchbearer.TRAIN_STEPS]):
            # Extract batch
            self._load_batch_standard('train', state)
            _callbacks.on_sample(state)

            # Zero grads
            state[torchbearer.OPTIMIZER].zero_grad()

            # Forward pass
            if pass_state:
                state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                    state[torchbearer.X], state=state)
            else:
                state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                    state[torchbearer.X])
            _callbacks.on_forward(state)

            # Loss Calculation
            state[torchbearer.LOSS] = state[torchbearer.CRITERION](
                state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])

            _callbacks.on_criterion(state)
            state[torchbearer.METRICS] = state[
                torchbearer.METRIC_LIST].process(state)

            # Backwards pass
            state[torchbearer.LOSS].backward()
            _callbacks.on_backward(state)

            # Update parameters
            state[torchbearer.OPTIMIZER].step()
            _callbacks.on_step_training(state)

            # A callback may request an early stop; leave the batch loop now.
            if state[torchbearer.STOP_TRAINING]:
                break

        state[torchbearer.METRICS].update(
            state[torchbearer.METRIC_LIST].process_final(state))
        # Keep a handle on the training metrics; validation may replace
        # state[METRICS], and the two sets are merged below.
        final_metrics = state[torchbearer.METRICS]

        _callbacks.on_end_training(state)

        # Validate
        if validation_generator is not None:
            state[torchbearer.VALIDATION_GENERATOR] = validation_generator
            state[torchbearer.VALIDATION_STEPS] = validation_steps
            self.eval()
            self._validate(state, _callbacks, pass_state)

        # Without validation state[METRICS] is final_metrics itself, so this is a no-op.
        final_metrics.update(state[torchbearer.METRICS])
        state[torchbearer.METRICS] = final_metrics
        _callbacks.on_end_epoch(state)

        if state[torchbearer.STOP_TRAINING]:
            break
    _callbacks.on_end(state)

    return state