# Example #1
def validation(valid_loader,
               model,
               criterion,
               num_classes,
               batch_size,
               classifier,
               batch_metrics=None):
    """Run one validation pass and return the updated metrics dict.

    Args:
        valid_loader: validation data loader
        model: model to validate
        criterion: loss criterion
        num_classes: number of classes
        batch_size: number of samples to process simultaneously
        classifier: True if doing a classification task, False if doing semantic segmentation
        batch_metrics: (int) Metrics computed every (int) batches. If left blank, will not perform metrics.

    Returns:
        (dict) valid_metrics with 'loss' (and optionally precision/recall/fscore) updated.
    """
    valid_metrics = create_metrics_dict(num_classes)
    model.eval()

    # No gradients are needed anywhere in validation; wrap the whole loop
    # instead of re-entering the context on every batch.
    with torch.no_grad():
        for index, data in enumerate(valid_loader):
            if classifier:
                inputs, labels = data
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = model(inputs)
                # Classification outputs are already (batch, num_classes).
                outputs_flatten = outputs
            else:
                # Segmentation: flatten the label map so it matches the
                # flattened per-pixel class scores.
                inputs = data['sat_img']
                labels = flatten_labels(data['map_img'])
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = model(inputs)
                outputs_flatten = flatten_outputs(outputs, num_classes)

            loss = criterion(outputs_flatten, labels)
            valid_metrics['loss'].update(loss.item(), batch_size)

            # Classification metrics are time consuming: only compute them
            # every `batch_metrics` batches when requested.
            if batch_metrics is not None and index % batch_metrics == 0:
                _, segmentation = torch.max(outputs_flatten, dim=1)
                valid_metrics = report_classification(
                    segmentation, labels, batch_size, valid_metrics)

    print('Validation Loss: {:.4f}'.format(valid_metrics['loss'].avg))
    if batch_metrics is not None:
        print('Validation precision: {:.4f}'.format(
            valid_metrics['precision'].avg))
        print('Validation recall: {:.4f}'.format(valid_metrics['recall'].avg))
        print('Validation f1-score: {:.4f}'.format(
            valid_metrics['fscore'].avg))

    return valid_metrics
def train(train_loader, model, criterion, optimizer, scheduler, num_classes, batch_size, classifier):
    """Train the model for one epoch and return the metrics of the training phase.

    Args:
        train_loader: training data loader
        model: model to train
        criterion: loss criterion
        optimizer: optimizer to use
        scheduler: learning rate scheduler (stepped once, after the epoch)
        num_classes: number of classes
        batch_size: number of samples to process simultaneously
        classifier: True if doing a classification task, False if doing semantic segmentation

    Returns:
        (dict) train_metrics with loss/iou/precision/recall/fscore updated.
    """
    model.train()
    train_metrics = create_metrics_dict(num_classes)

    for index, data in enumerate(train_loader):
        if classifier:
            inputs, labels = data
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            # Classification outputs are already (batch, num_classes).
            outputs_flatten = outputs
        else:
            # Segmentation: flatten the label map to match the flattened
            # per-pixel class scores.
            inputs = data['sat_img']
            labels = flatten_labels(data['map_img'])
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()
            # forward
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs_flatten = flatten_outputs(outputs, num_classes)

        # Drop references no longer needed before backward to reduce peak
        # memory; outputs_flatten still holds the graph for backprop.
        del outputs
        del inputs
        loss = criterion(outputs_flatten, labels)
        train_metrics['loss'].update(loss.item(), batch_size)

        loss.backward()
        optimizer.step()

        # Compute accuracy and iou every 2 batches and average values. Time consuming.
        if index % 2 == 0:
            _, segmentation = torch.max(outputs_flatten, dim=1)

            train_metrics = report_classification(segmentation, labels, batch_size, train_metrics)
            train_metrics = iou(segmentation, labels, batch_size, train_metrics)

    # PyTorch >= 1.1 requires scheduler.step() AFTER optimizer.step();
    # stepping at the start of the epoch skips the initial learning rate.
    scheduler.step()

    print('Training Loss: {:.4f}'.format(train_metrics['loss'].avg))
    print('Training iou: {:.4f}'.format(train_metrics['iou'].avg))
    print('Training precision: {:.4f}'.format(train_metrics['precision'].avg))
    print('Training recall: {:.4f}'.format(train_metrics['recall'].avg))
    print('Training f1-score: {:.4f}'.format(train_metrics['fscore'].avg))

    return train_metrics
def evaluation(eval_loader,
               model,
               criterion,
               num_classes,
               batch_size,
               task,
               ep_idx,
               progress_log,
               batch_metrics=None,
               dataset='val',
               num_devices=0):
    """
    Evaluate the model and return the updated metrics
    :param eval_loader: data loader
    :param model: model to evaluate
    :param criterion: loss criterion
    :param num_classes: number of classes
    :param batch_size: number of samples to process simultaneously
    :param task: 'segmentation' or 'classification'
    :param ep_idx: epoch index (for hypertrainer log)
    :param progress_log: progress log file (for hypertrainer log)
    :param batch_metrics: (int) Metrics computed every (int) batches. If left blank, will not perform metrics.
    :param dataset: (str) 'val' or 'tst'
    :param num_devices: (int) Number of GPU devices to use.
    :raises ValueError: if task is neither 'classification' nor 'segmentation'
    :return: (dict) eval_metrics
    """
    eval_metrics = create_metrics_dict(num_classes)
    model.eval()

    for index, data in enumerate(eval_loader):
        # Close the log handle each time instead of leaking one per batch.
        with progress_log.open('a', buffering=1) as log_file:
            log_file.write(
                tsv_line(ep_idx, dataset, index, len(eval_loader), time.time()))

        with torch.no_grad():
            if task == 'classification':
                inputs, labels = data
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = model(inputs)
                # Classification outputs are already (batch, num_classes).
                outputs_flatten = outputs
            elif task == 'segmentation':
                if num_devices > 0:
                    inputs = data['sat_img'].cuda()
                    labels = flatten_labels(data['map_img']).cuda()
                else:
                    inputs = data['sat_img']
                    labels = flatten_labels(data['map_img'])

                outputs = model(inputs)
                outputs_flatten = flatten_outputs(outputs, num_classes)
            else:
                # Fail fast instead of hitting an opaque NameError below.
                raise ValueError(f"Unknown task: {task}")

            loss = criterion(outputs_flatten, labels)
            eval_metrics['loss'].update(loss.item(), batch_size)

            if (dataset == 'val') and (batch_metrics is not None):
                # Compute metrics every n batches. Time consuming.
                if index % batch_metrics == 0:
                    _, segmentation = torch.max(outputs_flatten, dim=1)
                    eval_metrics = report_classification(
                        segmentation, labels, batch_size, eval_metrics)
            elif dataset == 'tst':
                # Test set: metrics on every batch.
                _, segmentation = torch.max(outputs_flatten, dim=1)
                eval_metrics = report_classification(segmentation, labels,
                                                     batch_size, eval_metrics)

    print(f"{dataset} Loss: {eval_metrics['loss'].avg}")
    # 'tst' metrics are computed regardless of batch_metrics; print them too.
    if (batch_metrics is not None) or (dataset == 'tst'):
        print(f"{dataset} precision: {eval_metrics['precision'].avg}")
        print(f"{dataset} recall: {eval_metrics['recall'].avg}")
        print(f"{dataset} fscore: {eval_metrics['fscore'].avg}")

    return eval_metrics