Пример #1
0
def train(loader, model, optimizer, epoch, writer):
    """Run one training epoch and write epoch-level summaries.

    Settings come from the module-level ``config``: ``[train] print_freq``,
    ``[contour] exclusive``, ``[param] weight_map``, ``[param] model`` and the
    ``branch_contour`` / ``branch_marker`` flags in that model's section.

    Args:
        loader: iterable (supporting ``len``) of batch dicts holding tensors
            under ``'image'``, ``'label'``, ``'label_c'``, ``'label_m'`` and,
            optionally, ``'weight'``.
        model: network to train. Its output is unpacked as
            ``(semantic, contour, marker)`` when both extra branches are
            enabled, ``(semantic, contour)`` when only the contour branch is,
            and used as-is otherwise.
        optimizer: optimizer updating the model's parameters.
        epoch: epoch index, used in console output and as the summary step.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.

    Returns:
        The epoch average (``AverageMeter.avg``) of the per-batch IoU of the
        main output, measured against the contour label when
        ``[contour] exclusive`` is set and the semantic label otherwise.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    iou = AverageMeter()    # semantic IoU (contour IoU in exclusive mode)
    iou_c = AverageMeter()  # contour IoU
    iou_m = AverageMeter()  # marker IoU
    print_freq = config['train'].getfloat('print_freq')
    only_contour = config['contour'].getboolean('exclusive')
    weight_map = config['param'].getboolean('weight_map')
    model_name = config['param']['model']
    with_contour = config.getboolean(model_name, 'branch_contour')
    with_marker = config.getboolean(model_name, 'branch_marker')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sets the module in training mode.
    model.train()
    end = time.time()
    n_step = len(loader)
    for i, data in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # move the batch to the compute device
        inputs = data['image'].to(device)
        labels = data['label'].to(device)
        labels_c = data['label_c'].to(device)
        labels_m = data['label_m'].to(device)
        # optional per-pixel loss weights; stays None without a weight map
        weights = None
        if weight_map and 'weight' in data:
            weights = data['weight'].to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward step
        outputs = model(inputs)
        # NOTE(review): a marker branch without a contour branch is not
        # unpacked here, so ``outputs_m`` would be undefined below -- confirm
        # that configuration is unsupported.
        if with_contour and with_marker:
            outputs, outputs_c, outputs_m = outputs
        elif with_contour:
            outputs, outputs_c = outputs
        # compute loss
        if only_contour:
            loss = contour_criterion(outputs, labels_c)
        else:
            # each enabled branch adds its own focal loss term
            loss = focal_criterion(outputs, labels, weights)
            if with_contour:
                loss += focal_criterion(outputs_c, labels_c, weights)
            if with_marker:
                loss += focal_criterion(outputs_m, labels_m, weights)
        # compute gradient and do backward step
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # measure accuracy and record loss
        # NOT instance-level IoU in training phase, for better speed & instance separation handled in post-processing
        losses.update(loss.item(), inputs.size(0))
        if only_contour:
            batch_iou = iou_mean(outputs, labels_c)
        else:
            batch_iou = iou_mean(outputs, labels)
        iou.update(batch_iou, inputs.size(0))
        if with_contour:
            batch_iou_c = iou_mean(outputs_c, labels_c)
            iou_c.update(batch_iou_c, inputs.size(0))
        if with_marker:
            batch_iou_m = iou_mean(outputs_m, labels_m)
            iou_m.update(batch_iou_m, inputs.size(0))
        # per-step summaries (currently disabled):
        #step = i + epoch * n_step
        #writer.add_scalar('training/loss', loss.item(), step)
        #writer.add_scalar('training/batch_elapse', batch_time.val, step)
        #writer.add_scalar('training/batch_iou', iou.val, step)
        #writer.add_scalar('training/batch_iou_c', iou_c.val, step)
        #writer.add_scalar('training/batch_iou_m', iou_m.val, step)
        # A missing (None) or zero print_freq disables periodic output instead
        # of failing on the modulo. The step is shown 1-based so the final
        # batch reads [n_step/n_step].
        if print_freq and (i + 1) % print_freq == 0:
            print(
                'Epoch: [{0}][{1}/{2}]\t'
                'Time: {batch_time.avg:.2f} (io: {data_time.avg:.2f})\t'
                'Loss: {loss.val:.4f} (avg: {loss.avg:.4f})\t'
                'IoU: {iou.avg:.3f} (Coutour: {iou_c.avg:.3f}, Marker: {iou_m.avg:.3f})\t'
                .format(
                    epoch, i + 1, n_step, batch_time=batch_time,
                    data_time=data_time, loss=losses, iou=iou, iou_c=iou_c, iou_m=iou_m
                )
            )
    # end of loop, dump epoch summary
    writer.add_scalar('training/epoch_loss', losses.avg, epoch)
    writer.add_scalar('training/epoch_iou', iou.avg, epoch)
    writer.add_scalar('training/epoch_iou_c', iou_c.avg, epoch)
    writer.add_scalar('training/epoch_iou_m', iou_m.avg, epoch)
    return iou.avg # return epoch average iou
Пример #2
0
def valid(loader, model, epoch, writer, n_step):
    """Evaluate the model on a validation loader and write epoch summaries.

    Settings come from the module-level ``config``: ``[contour] exclusive``,
    ``[param] weight_map``, ``[param] model`` and the ``branch_contour`` /
    ``branch_marker`` flags in that model's section.

    Args:
        loader: iterable of batch dicts holding tensors under ``'image'``,
            ``'label'``, ``'label_c'``, ``'label_m'`` and, optionally,
            ``'weight'``.
        model: network under evaluation. Its output is unpacked as
            ``(semantic, contour, marker)`` when both extra branches are
            enabled, ``(semantic, contour)`` when only the contour branch is,
            and used as-is otherwise.
        epoch: epoch index, used in console output and as the summary step.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.
        n_step: not used by this function; kept for call-site compatibility.

    Returns:
        The epoch average (``AverageMeter.avg``) of the per-batch IoU of the
        main output, measured against the contour label when
        ``[contour] exclusive`` is set and the semantic label otherwise.
    """
    iou = AverageMeter()    # semantic IoU (contour IoU in exclusive mode)
    iou_c = AverageMeter()  # contour IoU
    iou_m = AverageMeter()  # marker IoU
    losses = AverageMeter()
    only_contour = config['contour'].getboolean('exclusive')
    weight_map = config['param'].getboolean('weight_map')
    model_name = config['param']['model']
    with_contour = config.getboolean(model_name, 'branch_contour')
    with_marker = config.getboolean(model_name, 'branch_marker')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sets the model in evaluation mode.
    model.eval()
    # No gradients are needed for validation; disabling autograd avoids
    # building a graph and keeping activations alive for every batch.
    with torch.no_grad():
        for data in loader:
            # move the batch to the compute device
            inputs = data['image'].to(device)
            labels = data['label'].to(device)
            labels_c = data['label_c'].to(device)
            labels_m = data['label_m'].to(device)
            # optional per-pixel loss weights; stays None without a weight map
            weights = None
            if weight_map and 'weight' in data:
                weights = data['weight'].to(device)
            # forward step
            outputs = model(inputs)
            if with_contour and with_marker:
                outputs, outputs_c, outputs_m = outputs
            elif with_contour:
                outputs, outputs_c = outputs
            # compute loss
            if only_contour:
                loss = contour_criterion(outputs, labels_c)
            else:
                # each enabled branch adds its own focal loss term
                loss = focal_criterion(outputs, labels, weights)
                if with_contour:
                    loss += focal_criterion(outputs_c, labels_c, weights)
                if with_marker:
                    loss += focal_criterion(outputs_m, labels_m, weights)
            # measure accuracy and record loss (Non-instance level IoU)
            losses.update(loss.item(), inputs.size(0))
            if only_contour:
                batch_iou = iou_mean(outputs, labels_c)
            else:
                batch_iou = iou_mean(outputs, labels)
            iou.update(batch_iou, inputs.size(0))
            if with_contour:
                batch_iou_c = iou_mean(outputs_c, labels_c)
                iou_c.update(batch_iou_c, inputs.size(0))
            if with_marker:
                batch_iou_m = iou_mean(outputs_m, labels_m)
                iou_m.update(batch_iou_m, inputs.size(0))
    # end of loop, dump epoch summary
    writer.add_scalar('CV/epoch_loss', losses.avg, epoch)
    writer.add_scalar('CV/epoch_iou', iou.avg, epoch)
    writer.add_scalar('CV/epoch_iou_c', iou_c.avg, epoch)
    writer.add_scalar('CV/epoch_iou_m', iou_m.avg, epoch)
    print(
        'Epoch: [{0}]\t\tcross-validation\t'
        'Loss: N/A    (avg: {loss.avg:.4f})\t'
        'IoU: {iou.avg:.3f} (Coutour: {iou_c.avg:.3f}, Marker: {iou_m.avg:.3f})\t'
        .format(
            epoch, loss=losses, iou=iou, iou_c=iou_c, iou_m=iou_m
        )
    )
    return iou.avg # return epoch average iou
Пример #3
0
def valid(loader, model, epoch, writer, n_step):
    """Evaluate the model on a validation loader and write epoch summaries.

    Settings come from the module-level ``config``: ``[contour] exclusive``
    and ``[param] weight_map``. Which extra output branches exist is decided
    from the model's class: ``CAMUNet`` yields (semantic, contour, marker)
    outputs, ``DCAN`` / ``CAUNet`` yield (semantic, contour), anything else a
    single output.

    Args:
        loader: iterable of batch dicts holding tensors under ``'image'``,
            ``'label'``, ``'label_c'``, ``'label_m'`` and, optionally,
            ``'weight'``.
        model: network under evaluation.
        epoch: epoch index, used in console output and as the summary step.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.
        n_step: not used by this function; kept for call-site compatibility.
    """
    iou = AverageMeter()    # semantic IoU (contour IoU in exclusive mode)
    iou_c = AverageMeter()  # contour IoU
    iou_m = AverageMeter()  # marker IoU
    losses = AverageMeter()
    only_contour = config['contour'].getboolean('exclusive')
    weight_map = config['param'].getboolean('weight_map')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The model's class does not change during the loop, so resolve the
    # output branches once. Each branch's loss is added exactly once, even if
    # a model matches more than one of the isinstance checks.
    with_marker = isinstance(model, CAMUNet)
    with_contour = with_marker or isinstance(model, (DCAN, CAUNet))

    # Sets the model in evaluation mode.
    model.eval()
    # No gradients are needed for validation; disabling autograd avoids
    # building a graph and keeping activations alive for every batch.
    with torch.no_grad():
        for data in loader:
            # move the batch to the compute device (replaces the deprecated
            # Variable wrapping and explicit .cuda() calls)
            inputs = data['image'].to(device)
            labels = data['label'].to(device)
            labels_c = data['label_c'].to(device)
            labels_m = data['label_m'].to(device)
            # optional per-pixel loss weights; stays None without a weight map.
            # `non_blocking` replaces the old `async=True` keyword, which is a
            # SyntaxError on Python 3.7+.
            weights = None
            if weight_map and 'weight' in data:
                weights = data['weight'].to(device, non_blocking=True)
            # forward step
            outputs = model(inputs)
            if with_marker:
                outputs, outputs_c, outputs_m = outputs
            elif with_contour:
                outputs, outputs_c = outputs
            # compute loss
            if only_contour:
                loss = contour_criterion(outputs, labels_c)
            else:
                # each available branch adds its own focal loss term
                loss = focal_criterion(outputs, labels, weights)
                if with_contour:
                    loss += focal_criterion(outputs_c, labels_c, weights)
                if with_marker:
                    loss += focal_criterion(outputs_m, labels_m, weights)
            # measure accuracy and record loss (Non-instance level IoU);
            # .item() replaces loss.data[0], which fails on 0-dim tensors
            losses.update(loss.item(), inputs.size(0))
            if only_contour:
                batch_iou = iou_mean(outputs, labels_c)
            else:
                batch_iou = iou_mean(outputs, labels)
            iou.update(batch_iou, inputs.size(0))
            if with_contour:
                batch_iou_c = iou_mean(outputs_c, labels_c)
                iou_c.update(batch_iou_c, inputs.size(0))
            if with_marker:
                batch_iou_m = iou_mean(outputs_m, labels_m)
                iou_m.update(batch_iou_m, inputs.size(0))
    # end of loop, dump epoch summary
    writer.add_scalar('CV/epoch_loss', losses.avg, epoch)
    writer.add_scalar('CV/epoch_iou', iou.avg, epoch)
    writer.add_scalar('CV/epoch_iou_c', iou_c.avg, epoch)
    writer.add_scalar('CV/epoch_iou_m', iou_m.avg, epoch)
    print(
        'Epoch: [{0}]\t\tcross-validation\t'
        'Loss: N/A    (avg: {loss.avg:.4f})\t'
        'IoU: {iou.avg:.3f} (Coutour: {iou_c.avg:.3f}, Marker: {iou_m.avg:.3f})\t'
        .format(
            epoch, loss=losses, iou=iou, iou_c=iou_c, iou_m=iou_m
        )
    )