Example No. 1
 def test_return_all_elements(self):
     for size in [100, 10000, 100000]:
         shuffled_dataset = list(shuffled_iterator(range(size)))
         self.assertNotEqual(
             shuffled_dataset[:10], list(range(10)),
             'It is highly unlikely that the original order is preserved')
         self.assertSetEqual(set(shuffled_dataset), set(range(size)),
                             'Some returned elements are missing.')
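
The test above pins down the contract of `shuffled_iterator`: every element of the input appears exactly once in the output, and the original order is very unlikely to be preserved. As a reference point, here is a minimal sketch of a generator with that contract, using a bounded shuffle buffer (the `buffer_size` keyword is the same one passed in Example No. 2); this sketch is an assumption about how such an iterator could work, not the library's actual implementation.

import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def shuffled_iterator(input_iterator: Iterable[T], buffer_size: int = 500) -> Iterator[T]:
    """Yield every element of `input_iterator` exactly once, in a locally shuffled order.

    Sketch: keep a buffer of up to `buffer_size` elements; once it is full, each
    new element replaces a randomly chosen buffered element, which is yielded.
    At the end, the remaining buffer is drained in random order.
    """
    buffer = []
    for element in input_iterator:
        if len(buffer) < buffer_size:
            buffer.append(element)
        else:
            idx = random.randrange(buffer_size)
            yield buffer[idx]
            buffer[idx] = element
    random.shuffle(buffer)
    yield from buffer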
Example No. 2
    def minibatch_iterator(
        self,
        tensorized_data: Iterator[Tuple[TTensorizedDatapoint,
                                        Optional[TRawDatapoint]]],
        device: Union[str, torch.device],
        max_minibatch_size: int,
        yield_partial_minibatches: bool = True,
        shuffle_input: bool = False,
        parallelize: bool = True,
    ) -> Iterator[Tuple[Dict[str, Any], List[Optional[TRawDatapoint]]]]:
        """
        An iterator that yields minibatches to be consumed by a neural module.
        :param tensorized_data: An iterator of tensorized data, commonly the output of `tensorize_dataset()`.
        :param device: The device on which the tensorized data will be stored.
        :param max_minibatch_size: the maximum size of the minibatch.
        :param yield_partial_minibatches: If True, yield partial minibatches, i.e. minibatches that have not
            reached `max_minibatch_size` and that `extend_minibatch_with` did not consider full.
            Users might want to set this to False when training.
        :param shuffle_input: Should the `tensorized_data` be shuffled? (e.g. during training)
        :param parallelize: If True, minibatching will be parallelized. This may make debugging harder.
        :return: An iterator that yields tuples of the minibatch data and the raw datapoints (if present in `tensorized_data`).
        """
        assert self.__metadata_initialized, "Metadata has not been initialized."

        if shuffle_input:
            tensorized_data = shuffled_iterator(tensorized_data,
                                                buffer_size=500)

        unfinalized_minibatches = ThreadedIterator(
            self.__iterate_unfinalized_minibatches(tensorized_data,
                                                   max_minibatch_size,
                                                   yield_partial_minibatches),
            enabled=parallelize,
        )
        yield from ThreadedIterator(
            ((self.finalize_minibatch(d[0], device), d[1])
             for d in unfinalized_minibatches),
            enabled=parallelize,
        )
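
A hedged usage sketch for the iterator above: assuming a component instance with this `minibatch_iterator` and a `tensorize_dataset()` method (which the docstring names as the usual producer of `tensorized_data`), a training loop could consume minibatches roughly as follows. `neural_module` and `raw_training_data` are illustrative placeholders, and the assumption that calling the module returns a loss follows Examples No. 3 and 4.

import torch

# Placeholder objects: `neural_module` is an instance of the component defining
# minibatch_iterator(), and `raw_training_data` is its raw input dataset.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tensorized_data = neural_module.tensorize_dataset(raw_training_data)  # assumed producer

for minibatch, raw_datapoints in neural_module.minibatch_iterator(
        tensorized_data,
        device=device,
        max_minibatch_size=64,
        yield_partial_minibatches=False,  # commonly False when training
        shuffle_input=True,               # shuffles with a 500-element buffer
        parallelize=True):
    loss = neural_module(**minibatch)  # assumed: the module's forward pass returns a loss
    loss.backward()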
Example No. 3
    def train(self,
              training_data: Iterable[InputData],
              validation_data: Iterable[InputData],
              show_progress_bar: bool = True,
              patience: int = 5,
              initialize_metadata: bool = True,
              exponential_running_average_factor: float = 0.97,
              get_parameters_to_freeze: Optional[Callable[[], Set]] = None,
              parallel_minibatch_creation: bool = False,
              device: Optional[Union[str, torch.device]] = None) -> None:
        """
        The training-validation loop for `BaseComponent`s.

        :param training_data: An iterable that yields the full training data on each iteration.
        :param validation_data: An iterable that yields the full validation data on each iteration.
        :param show_progress_bar: If True, show a progress bar.
        :param patience: The number of epochs without improvement before early stopping kicks in.
        :param initialize_metadata: If true, initialize the metadata from the training_data. Otherwise,
            assume that the model that is being trained has its metadata already initialized.
        :param exponential_running_average_factor: The smoothing factor of the exponential running average
            of the training loss displayed in the progress bar.
        :param get_parameters_to_freeze: The (optional) callable that returns the set of parameters to freeze during training.
        :param parallel_minibatch_creation: If True, the minibatches will be created in a separate thread.
        :param device: The device to train on. If None, the first CUDA device is used when available, otherwise the CPU.
        """
        if initialize_metadata:
            self.__load_metadata(training_data)

        self.LOGGER.info('Model has %s parameters',
                         self.__model.num_parameters())
        self.LOGGER.debug('Data Tensorization Started...')

        def data_to_tensor_iterator(data):
            for datapoint in data:
                tensorized_datapoint = self.__model.load_data_from_sample(
                    datapoint)
                if tensorized_datapoint is not None:
                    yield tensorized_datapoint

        def training_tensors():
            yield from ThreadedIterator(
                original_iterator=data_to_tensor_iterator(training_data),
                max_queue_size=10 * self.__minibatch_size)

        def validation_tensors():
            yield from ThreadedIterator(
                original_iterator=data_to_tensor_iterator(validation_data),
                max_queue_size=10 * self.__minibatch_size)

        def minibatch_iterator(
                data_iterator: Iterator[TensorizedData],
                return_partial_minibatches: bool = False) -> Iterator[Tuple[Dict, int]]:
            while True:
                mb_data, batch_is_full, num_elements = self.__model.create_minibatch(
                    data_iterator, max_num_items=self.__minibatch_size)
                if num_elements == 0:
                    break
                elif not batch_is_full and not return_partial_minibatches:
                    break  # Do not return partial minibatches when the iterator is exhausted.
                else:
                    yield mb_data, num_elements

        if device is None:
            device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.__model.to(device)
        self.LOGGER.info('Using %s for training.', device)

        if get_parameters_to_freeze is None:
            get_parameters_to_freeze = lambda: set()
        trainable_parameters = set(
            self.__model.parameters()) - get_parameters_to_freeze()
        optimizer = self.__create_optimizer(trainable_parameters)
        scheduler = None if self.__create_scheduler is None else self.__create_scheduler(
            optimizer)

        for hook in self.__training_start_hooks:
            hook(self.__model, optimizer)

        best_loss = float('inf')  # type: float
        num_epochs_not_improved = 0  # type: int
        for epoch in range(self.__max_num_epochs):
            self.__model.train()

            data_iter = shuffled_iterator(training_tensors())
            sum_epoch_loss = 0.0
            running_avg_loss = 0.0
            num_minibatches = 0
            num_samples = 0

            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Training',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar:
                for step_idx, (mb_data, num_elements) in enumerate(
                        ThreadedIterator(minibatch_iterator(
                            data_iter, return_partial_minibatches=False),
                                         enabled=parallel_minibatch_creation)):
                    optimizer.zero_grad()
                    mb_loss = self.__model(**mb_data)
                    mb_loss.backward()

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step(epoch_idx=epoch, epoch_step=step_idx)

                    loss = float(mb_loss.cpu())
                    if math.isnan(loss):
                        raise Exception('Training Loss has a NaN value.')

                    sum_epoch_loss += loss
                    num_minibatches += 1
                    num_samples += num_elements

                    if num_minibatches == 1:  # First minibatch
                        running_avg_loss = loss
                    else:
                        running_avg_loss = exponential_running_average_factor * running_avg_loss + (
                            1 - exponential_running_average_factor) * loss
                    progress_bar.update()
                    progress_bar.set_postfix(Loss=f'{running_avg_loss:.2f}')

            elapsed_time = time.time() - start_time  # type: float
            self.LOGGER.info('Training complete in %.1fsec [%.2f samples/sec]',
                             elapsed_time, (num_samples / elapsed_time))
            assert num_minibatches > 0, 'No training minibatches were created. The minibatch size may be too large or the training dataset size too small.'
            self.LOGGER.info('Epoch %i: Avg Train Loss %.2f', epoch + 1,
                             sum_epoch_loss / num_minibatches)
            train_metrics = self.__model.report_metrics()
            for epoch_hook in self.__train_epoch_end_hooks:
                epoch_hook(self.__model, epoch, train_metrics)
            if len(train_metrics) > 0:
                self.LOGGER.info('Training Metrics: %s',
                                 json.dumps(train_metrics, indent=2))

            # Now do validation!
            self.__model.eval()
            data_iter = validation_tensors()
            sum_epoch_loss = 0
            num_minibatches = 0
            num_samples = 0
            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Validation',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar, torch.no_grad():
                for mb_data, num_elements in ThreadedIterator(
                        minibatch_iterator(data_iter,
                                           return_partial_minibatches=True),
                        enabled=parallel_minibatch_creation):
                    mb_loss = self.__model(**mb_data)

                    loss = float(mb_loss.cpu())
                    if math.isnan(loss):
                        raise Exception('Validation Loss has a NaN value.')

                    sum_epoch_loss += loss
                    num_minibatches += 1
                    num_samples += num_elements

                    progress_bar.update()
                    progress_bar.set_postfix(
                        Loss=f'{sum_epoch_loss / num_minibatches:.2f}')

            elapsed_time = time.time() - start_time
            assert num_samples > 0, 'No validation data was found.'
            validation_loss = sum_epoch_loss / num_minibatches
            self.LOGGER.info(
                'Validation complete in %.1fsec [%.2f samples/sec]',
                elapsed_time, (num_samples / elapsed_time))
            self.LOGGER.info('Epoch %i: Avg Valid Loss %.2f', epoch + 1,
                             validation_loss)
            validation_metrics = self.__model.report_metrics()
            for epoch_hook in self.__validation_epoch_end_hooks:
                epoch_hook(self.__model, epoch, validation_metrics)
            if len(validation_metrics) > 0:
                self.LOGGER.info('Validation Metrics: %s',
                                 json.dumps(validation_metrics, indent=2))

            if validation_loss < best_loss:
                self.LOGGER.info('Best loss so far --- Saving model.')
                num_epochs_not_improved = 0
                self.__save_current_model()
                best_loss = validation_loss
            else:
                num_epochs_not_improved += 1
                if num_epochs_not_improved > patience:
                    self.LOGGER.warning(
                        'After %s epochs loss has not improved. Stopping.',
                        num_epochs_not_improved)
                    break

        # Restore the best model that was found.
        self.restore_model()
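
A usage sketch for this training loop, assuming the enclosing component (`trainer` below, a placeholder name) was constructed elsewhere with a model, an optimizer factory, and a minibatch size; only the keyword arguments documented in the docstring are used.

import torch

# Placeholder objects: `trainer` is the component defining train(), while
# `train_samples` and `valid_samples` are iterables over the raw datapoints.
trainer.train(
    training_data=train_samples,
    validation_data=valid_samples,
    show_progress_bar=True,
    patience=5,                               # tolerate 5 epochs without improvement
    initialize_metadata=True,
    exponential_running_average_factor=0.97,
    get_parameters_to_freeze=lambda: set(),   # freeze nothing
    parallel_minibatch_creation=True,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)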
Example No. 4
    def train(self,
              training_data: Iterable[InputData],
              validation_data: Iterable[InputData],
              show_progress_bar: bool = True,
              patience: int = 5,
              initialize_metadata: bool = True,
              exponential_running_average_factor: float = 0.97,
              parameters_to_freeze: Optional[Set] = None) -> None:
        """
        The training-validation loop for `BaseComponent`s.

        :param training_data: An iterable that yields the full training data on each iteration.
        :param validation_data: An iterable that yields the full validation data on each iteration.
        :param show_progress_bar: If True, show a progress bar.
        :param patience: The number of epochs without improvement before early stopping kicks in.
        :param initialize_metadata: If true, initialize the metadata from the training_data. Otherwise,
            assume that the model that is being trained has its metadata already initialized.
        :param exponential_running_average_factor: The smoothing factor of the exponential running average
            of the training loss displayed in the progress bar.
        :param parameters_to_freeze: The (optional) set of parameters to freeze during training.
        """
        if initialize_metadata:
            self.__load_metadata(training_data)

        self.LOGGER.info('Model has %s parameters',
                         self.__model.num_parameters())
        self.LOGGER.debug('Data Tensorization Started...')

        def training_tensors():
            yield from ThreadedIterator(
                original_iterator=(self.__model.load_data_from_sample(d)
                                   for d in training_data),
                max_queue_size=10 * self.__minibatch_size)

        def validation_tensors():
            yield from ThreadedIterator(
                original_iterator=(self.__model.load_data_from_sample(d)
                                   for d in validation_data),
                max_queue_size=10 * self.__minibatch_size)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.LOGGER.info('Using %s for training.', device)
        if torch.cuda.is_available():
            self.__model.cuda()
        else:
            self.__model.cpu()

        if parameters_to_freeze is None:
            parameters_to_freeze = set()
        trainable_parameters = set(
            self.__model.parameters()) - parameters_to_freeze
        optimizer = self.__create_optimizer(trainable_parameters)

        best_loss = float('inf')  # type: float
        num_epochs_not_improved = 0  # type: int
        for epoch in range(self.__max_num_epochs):
            self.__model.train()

            data_iter = shuffled_iterator(training_tensors())
            sum_epoch_loss = 0.0
            running_avg_loss = 0.0
            num_minibatches = 0
            num_samples = 0

            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Training',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar:
                while True:
                    mb_data, data_iterator_exhausted, num_elements = self.__model.create_minibatch(
                        data_iter, max_num_items=self.__minibatch_size)
                    if data_iterator_exhausted or num_elements == 0:
                        break  # Do not consider half-full or empty minibatches
                    optimizer.zero_grad()
                    mb_loss = self.__model(**mb_data)
                    mb_loss.backward()

                    optimizer.step()
                    num_minibatches += 1
                    num_samples += num_elements
                    sum_epoch_loss += float(mb_loss.cpu())
                    if num_minibatches == 1:  # First minibatch
                        running_avg_loss = float(mb_loss.cpu())
                    else:
                        running_avg_loss = exponential_running_average_factor * running_avg_loss + (
                            1 - exponential_running_average_factor) * float(
                                mb_loss.cpu())
                    progress_bar.update()
                    progress_bar.set_postfix(Loss=f'{running_avg_loss:.2f}')

            elapsed_time = time.time() - start_time  # type: float
            self.LOGGER.info('Training complete in %.1fsec [%.2f samples/sec]',
                             elapsed_time, (num_samples / elapsed_time))
            assert num_minibatches > 0, 'No training minibatches were created. The minibatch size may be too large or the training dataset size too small.'
            self.LOGGER.info('Epoch %i: Avg Train Loss %.2f', epoch + 1,
                             sum_epoch_loss / num_minibatches)
            train_metrics = self.__model.report_metrics()
            for epoch_hook in self.__train_epoch_end_hooks:
                epoch_hook(self.__model, epoch, train_metrics)
            if len(train_metrics) > 0:
                self.LOGGER.info('Training Metrics: %s',
                                 json.dumps(train_metrics, indent=2))

            # Now do validation!
            self.__model.eval()
            data_iter = validation_tensors()
            sum_epoch_loss = 0
            num_minibatches = 0
            num_samples = 0
            start_time = time.time()
            self.__model.reset_metrics()
            with tqdm(desc='Validation',
                      disable=not show_progress_bar,
                      leave=False) as progress_bar, torch.no_grad():
                while True:
                    mb_data, data_iterator_exhausted, num_elements = self.__model.create_minibatch(
                        data_iter, max_num_items=self.__minibatch_size)
                    if num_elements == 0:
                        break  # No more elements could be found in the data_iter.
                    mb_loss = self.__model(**mb_data)
                    num_minibatches += 1
                    num_samples += num_elements
                    sum_epoch_loss += float(mb_loss.cpu())
                    progress_bar.update()
                    progress_bar.set_postfix(
                        Loss=f'{sum_epoch_loss / num_minibatches:.2f}')
                    if data_iterator_exhausted:
                        break  # No more elements in the data iterator

            elapsed_time = time.time() - start_time  # type: float
            assert num_samples > 0, 'No validation data was found.'
            validation_loss = sum_epoch_loss / num_minibatches
            self.LOGGER.info(
                'Validation complete in %.1fsec [%.2f samples/sec]',
                elapsed_time, (num_samples / elapsed_time))
            self.LOGGER.info('Epoch %i: Avg Valid Loss %.2f', epoch + 1,
                             validation_loss)
            validation_metrics = self.__model.report_metrics()
            for epoch_hook in self.__validation_epoch_end_hooks:
                epoch_hook(self.__model, epoch, validation_metrics)
            if len(validation_metrics) > 0:
                self.LOGGER.info('Validation Metrics: %s',
                                 json.dumps(validation_metrics, indent=2))

            if validation_loss < best_loss:
                self.LOGGER.info('Best loss so far --- Saving model.')
                num_epochs_not_improved = 0
                self.__save_current_model()
                best_loss = validation_loss
            else:
                num_epochs_not_improved += 1
                if num_epochs_not_improved > patience:
                    self.LOGGER.warning(
                        'After %s epochs loss has not improved. Stopping.',
                        num_epochs_not_improved)
                    break

        # Restore the best model that was found.
        self.restore_model()
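
Example No. 4 is an earlier variant of the same loop: parameters to freeze are passed as a concrete set instead of a callable, there is no learning-rate scheduler, and the device is always selected automatically. A sketch of freezing part of the model with this interface, where `trainer`, `model`, `train_samples`, `valid_samples`, and the `encoder` attribute are illustrative placeholders:

# Hypothetical: exclude a submodule's parameters from optimization via the
# older parameters_to_freeze argument.
frozen_parameters = set(model.encoder.parameters())  # `encoder` is an assumed attribute name

trainer.train(
    training_data=train_samples,
    validation_data=valid_samples,
    parameters_to_freeze=frozen_parameters,
)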