Ejemplo n.º 1
0
    def run_one_mini_batch(cls, model, criterion, optimizer, input, target,
                           **kwargs):
        """Run one training step on a mini-batch and record loss and accuracy.

        See parent method for documentation.

        Extra-Parameters
        ----------
        optimizer : torch.optim
            The optimizer used to perform the weight update.
        """
        # Number of targets in this mini-batch (assumed to be one per sample).
        # This is the value the model receives as `target_size`, so the meters
        # are weighted with it as well. `len(input)` is only the batch size
        # for plain tensors/sequences; for container-style batch objects it
        # may count something else entirely.
        batch_size = target.shape[0]

        # Compute output
        output = model(input, target_size=batch_size)

        # Compute and record the loss
        loss = criterion(output, target)
        MetricLogger().update(key='loss', value=loss.item(), n=batch_size)

        # Compute and record the (top-1) accuracy
        acc = accuracy(output.data, target.data, topk=(1, ))[0]
        MetricLogger().update(key='accuracy', value=acc[0], n=batch_size)

        # Reset gradient
        optimizer.zero_grad()
        # Compute gradients
        loss.backward()
        # Perform a step by updating the weights
        optimizer.step()
Ejemplo n.º 2
0
    def start_of_the_epoch(cls, model, **kwargs):
        """Put the model in training mode and register the scalar meters.

        See parent method for documentation.

        Extra-Parameters
        ----------
        model : torch.nn.module
            The network model being used.
        """
        model.train()

        # One scalar meter for the main metric of this routine, one for the loss
        logger = MetricLogger()
        for meter_tag in (cls.main_metric(), 'loss'):
            logger.add_scalar_meter(tag=meter_tag)
Ejemplo n.º 3
0
    def end_of_the_epoch(cls,
                         data_loader,
                         epoch,
                         logging_label,
                         multi_run_label="",
                         current_log_folder=None,
                         **kwargs):
        """Log the confusion matrix and the classification reports.

        See parent method for documentation.

        Extra-Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The dataloader of the current set.
        epoch : int
            Number of the epoch (for logging purposes).
        logging_label : string
            Label for logging purposes. Typically 'train', 'test' or 'val'.
            It's prepended to the logging output path and messages.
        multi_run_label : string
            Suffix appended to the tags and meter keys; empty by default.
        current_log_folder : string
            Passed as `output_folder` when the per-file classification
            results are saved as csv (only done when logging 'test').
        """
        class_names = data_loader.dataset.classes
        has_run_label = len(multi_run_label) > 0

        # Render the confusion matrix as a heatmap and send it to TB
        heatmap = MetricLogger()['confusion_matrix'].make_heatmap(class_names)
        TBWriter().save_image(
            tag=logging_label + '/confusion_matrix' + multi_run_label,
            image=heatmap,
            global_step=epoch)

        # Text classification report, produced once per epoch
        epoch_report = MetricLogger()[
            'confusion_matrix'].get_classification_report(class_names)
        run_suffix = f' and run {multi_run_label}' if has_run_label else ''
        TBWriter().add_text(
            tag='Classification Report for epoch {}{}\n'.format(
                epoch, run_suffix),
            text_string='\n' + epoch_report,
            global_step=epoch)

        # Everything below is done only during testing
        if logging_label != 'test':
            return

        results_key = 'classification_results{}'.format(multi_run_label)
        run_suffix = ' run{}'.format(multi_run_label) if has_run_label else ''

        # Save the classification output as a csv
        MetricLogger()[results_key].save_csv(
            output_folder=current_log_folder,
            multi_run_label=multi_run_label)

        # Log the per-file report as text to TB
        per_file_report = MetricLogger()[results_key].get_report()
        TBWriter().add_text(
            tag='Classification per test file {}\n'.format(run_suffix),
            text_string='\n' + per_file_report)
Ejemplo n.º 4
0
    def start_of_the_epoch(cls, model, num_classes, data_loader, multi_run_label, **kwargs):
        """Put the model in evaluation mode and register all the meters.

        See parent method for documentation.

        Extra-Parameters
        ----------
        model : torch.nn.module
            The network model being used.
        num_classes : int
            Number of classes in the dataset
        data_loader : torch.utils.data.DataLoader
            The dataloader of the current set. The file list for the
            classification results meter is read from
            `data_loader.dataset.config['file_names']`.
        multi_run_label : string
            Not used by this method.
        """
        model.eval()

        logger = MetricLogger()

        # Scalar meters: main metric and loss
        logger.add_scalar_meter(tag=cls.main_metric())
        logger.add_scalar_meter(tag='loss')

        # Confusion matrix over all the classes
        logger.add_confusion_matrix_meter(tag='confusion_matrix',
                                          num_classes=num_classes)

        # Per-file classification results
        file_names = data_loader.dataset.config['file_names']
        logger.add_classification_results_meter(tag='classification_results',
                                                file_list=file_names)
Ejemplo n.º 5
0
    def start_of_the_epoch(cls, model, num_classes, **kwargs):
        """Put the model in evaluation mode and register the meters.

        See parent method for documentation.

        Extra-Parameters
        ----------
        model : torch.nn.module
            The network model being used.
        num_classes : int
            Number of classes in the dataset
        """
        model.eval()

        meters = MetricLogger()

        # Scalar meters first (main metric, then loss) ...
        for scalar_tag in (cls.main_metric(), 'loss'):
            meters.add_scalar_meter(tag=scalar_tag)

        # ... then the confusion matrix, sized by the number of classes
        meters.add_confusion_matrix_meter(num_classes=num_classes,
                                          tag='confusion_matrix')
Ejemplo n.º 6
0
    def run_one_mini_batch(cls, model, criterion, input, target, multi_run_label, **kwargs):
        """Evaluate one mini-batch and update all the registered meters.

        See parent method for documentation.

        Updates the 'loss', 'accuracy', 'confusion_matrix' and
        'classification_results' meters. `multi_run_label` is accepted but
        not used by this method.
        """
        # Number of targets in this mini-batch (assumed to be one per sample).
        # `input` is a batch object carrying attributes (see `file_name_ind`
        # below), so `len(input)` is not a reliable sample count; use the same
        # value that is handed to the model as `target_size`.
        batch_size = target.shape[0]

        # Compute output
        output = model(input, target_size=batch_size)

        # Compute and record the loss
        loss = criterion(output, target)
        MetricLogger().update(key='loss', value=loss.item(), n=batch_size)

        # Compute and record the (top-1) accuracy
        acc = accuracy(output.data, target.data, topk=(1,))[0]
        MetricLogger().update(key='accuracy', value=acc[0], n=batch_size)

        # Move predictions and ground truth to numpy once; both meters below
        # consume the same arrays.
        predictions = np.argmax(output.data.cpu().numpy(), axis=1)
        ground_truth = target.cpu().numpy()

        # Update the confusion matrix
        MetricLogger().update(key='confusion_matrix', p=predictions, t=ground_truth)
        # Update the classification results
        MetricLogger().update(key='classification_results', p=predictions, t=ground_truth,
                              f_ind=input.file_name_ind.cpu().numpy())
Ejemplo n.º 7
0
    def run_one_mini_batch(cls, model, criterion, input, target, **kwargs):
        """Evaluate one mini-batch and update loss, accuracy and confusion matrix.

        See parent method for documentation.
        """
        logger = MetricLogger()

        # Forward pass
        scores = model(input)

        # The target arrives as a mapping; the labels are under 'category_id'
        labels = target['category_id']

        # Loss for this mini-batch
        batch_loss = criterion(scores, labels)
        batch_size = len(input)
        logger.update(key='loss', value=batch_loss.item(), n=batch_size)

        # Top-1 accuracy for this mini-batch
        top1 = accuracy(scores.data, labels.data, topk=(1, ))[0]
        logger.update(key='accuracy', value=top1[0], n=batch_size)

        # Confusion matrix: highest-scoring class vs. ground truth
        predicted_classes = np.argmax(scores.data.cpu().numpy(), axis=1)
        logger.update(key='confusion_matrix',
                      p=predicted_classes,
                      t=labels.cpu().numpy())
Ejemplo n.º 8
0
    def run(cls,
            data_loader,
            epoch,
            log_interval,
            logging_label,
            batch_lr_schedulers,
            run=None,
            **kwargs):
        """
        Training routine: one full pass over `data_loader`.

        Resets the metric logger, calls `start_of_the_epoch`, processes every
        mini-batch with `run_one_mini_batch`, steps the per-batch LR
        schedulers, logs scalar meters to Tensorboard and finally calls
        `end_of_the_epoch`.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The dataloader of the current set.
        epoch : int
            Number of the epoch (for logging purposes).
        log_interval : int
            Interval limiting the logging of mini-batches.
        logging_label : string
            Label for logging purposes. Typically 'train', 'test' or 'val'.
            It's prepended to the logging output path and messages.
        batch_lr_schedulers : list(torch.optim.lr_scheduler)
            List of lr schedulers to call step() on after every batch
        run : int
            Number of run, used in multi-run context to discriminate the different runs

        Returns
        ----------
        Main metric : float
            Main metric of the model on the evaluated split, or 0 if no meter
            for the main metric is registered at the end of the epoch.

        Raises
        ----------
        AttributeError
            If, when logging to console, neither the main metric nor the
            'loss' meter is registered in the metric logger.
        """
        # 'run' is injected in kwargs at runtime in RunMe.py IFF it is a multi-run event
        multi_run_label = f"_{run}" if run is not None else ""

        # Instantiate the counter
        MetricLogger().reset(postfix=multi_run_label)

        # Custom routine to run at the start of the epoch
        cls.start_of_the_epoch(data_loader=data_loader,
                               epoch=epoch,
                               logging_label=logging_label,
                               multi_run_label=multi_run_label,
                               **kwargs)

        # Loop invariants: looked up once instead of on every mini-batch
        num_batches = len(data_loader)
        main_metric = cls.main_metric()

        # Iterate over whole training set
        end = time.time()
        pbar = tqdm(enumerate(data_loader),
                    total=num_batches,
                    unit='batch',
                    ncols=130,
                    leave=False)
        for batch_idx, data in pbar:
            # Measure data loading time
            data_time = time.time() - end

            # Moving data to GPU. The batch object itself is the model input
            # and carries its targets in `.y`.
            batch, target = cls.move_to_device(input=data,
                                               target=data.y,
                                               **kwargs)

            cls.run_one_mini_batch(input=batch,
                                   target=target,
                                   multi_run_label=multi_run_label,
                                   **kwargs)

            # Update the LR according to the scheduler
            for lr_scheduler in batch_lr_schedulers:
                lr_scheduler.step()

            # Add metrics to Tensorboard for the last mini-batch value
            for tag, meter in MetricLogger():
                if isinstance(meter, ScalarValue):
                    TBWriter().add_scalar(
                        tag=logging_label + '/mb_' + tag,
                        scalar_value=meter.value,
                        global_step=epoch * num_batches + batch_idx)

            # Measure elapsed time for a mini-batch
            batch_time = time.time() - end
            end = time.time()

            # Log to console
            if batch_idx % log_interval == 0 and len(MetricLogger()) > 0:
                # Prefer the main metric; fall back to the loss
                if main_metric + multi_run_label in MetricLogger():
                    mlogger = MetricLogger()[main_metric]
                elif "loss" + multi_run_label in MetricLogger():
                    mlogger = MetricLogger()["loss"]
                else:
                    raise AttributeError(
                        f"Neither '{main_metric}{multi_run_label}' nor "
                        f"'loss{multi_run_label}' is registered in the "
                        f"MetricLogger; nothing to report for "
                        f"'{logging_label}'")
                pbar.set_description(
                    f'{logging_label} epoch [{epoch}][{batch_idx}/{num_batches}]'
                )
                pbar.set_postfix(Metric=f'{mlogger.global_avg:.3f}',
                                 Time=f'{batch_time:.3f}',
                                 Data=f'{data_time:.3f}')

        # Custom routine to run at the end of the epoch
        cls.end_of_the_epoch(data_loader=data_loader,
                             epoch=epoch,
                             logging_label=logging_label,
                             multi_run_label=multi_run_label,
                             **kwargs)

        # Add metrics to Tensorboard for the full-epoch value
        for tag, meter in MetricLogger():
            if isinstance(meter, ScalarValue):
                TBWriter().add_scalar(tag=logging_label + '/' + tag,
                                      scalar_value=meter.global_avg,
                                      global_step=epoch)

        if main_metric + multi_run_label in MetricLogger():
            return MetricLogger()[main_metric].global_avg
        # No meter for the main metric was registered for this run
        return 0