Esempio n. 1
0
def test(args, model, device, test_loader, meters, epoch, criterion):
    """Evaluate ``model`` on ``test_loader`` and report loss and Dice.

    Runs the model in eval mode with gradient tracking disabled and feeds
    each batch's loss and Dice coefficient into ``meters['loss']`` and
    ``meters['dice']``. After the pass, the meters' averages are sent to the
    module-level ``logger`` under ``test_loss`` / ``test_dice`` at step
    ``epoch``, and printed.

    Args:
        args: Parsed options (not read here; kept so the signature lines up
            with ``train``).
        model: Network to evaluate.
        device: Device each batch is moved to before the forward pass.
        test_loader: Iterable yielding ``(data, mask)`` batches; a channel
            axis is inserted at dim 1 of both.
        meters: Mapping with ``'loss'`` and ``'dice'`` meters exposing
            ``update(value)`` and ``.avg``.
        epoch: Current epoch, used as the logging step.
        criterion: Loss function called as ``criterion(output, mask)``.
    """
    testloss = meters['loss']
    testdice = meters['dice']

    model.eval()
    with torch.no_grad():
        for data, mask in test_loader:
            # Insert the channel axis: (N, ...) -> (N, 1, ...).
            data = data.unsqueeze(1).float().to(device)
            mask = mask.unsqueeze(1).float().to(device)
            output = model(data)
            loss = criterion(output, mask)
            # Store a Python float rather than a (possibly GPU-resident)
            # tensor so the meter average is a plain number, as in ``train``.
            testloss.update(loss.item())
            testdice.update(dice_coefficient(output, mask))

    # Log epoch-level averages once; logging inside the loop would write
    # one point per batch, all at the same ``epoch`` step.
    info = {'test_loss': testloss.avg, 'test_dice': testdice.avg}
    for tag, value in info.items():
        logger.scalar_summary(tag, value, epoch)

    print(
        '\nTest set: Average loss: {:.4f}, Average Dice Coefficient: {:.6f}\n'.
        format(testloss.avg, testdice.avg))
Esempio n. 2
0
def train(args,
          model,
          device,
          train_loader,
          optimizer,
          epoch,
          meters,
          criterion,
          savepath=None,
          savefile=None):
    """Train ``model`` for one epoch over ``train_loader``.

    For every batch, a channel axis is inserted at dim 1, ``data`` and
    ``mask`` are moved to ``device``, and one optimizer step is taken on
    ``criterion(output, mask)``. The batch loss and Dice coefficient are fed
    to ``meters['loss']`` and ``meters['dice']``.

    Every ``args.log_interval`` batches the function does the following:
    it prints progress; it sends loss/Dice scalars, parameter and gradient
    histograms, and up to two raw output maps (no sigmoid is applied here)
    to the module-level ``logger`` at step ``epoch``; and, when
    ``args.checkpoint`` is truthy, it calls ``save_checkpoint``.

    Args:
        args: Options object; reads ``log_interval`` and ``checkpoint``.
        model: Network being trained.
        device: Device both ``data`` and ``mask`` are moved to.
        train_loader: Iterable of ``(data, mask)`` batches; its ``len`` and
            ``.dataset`` are used for the progress message.
        optimizer: Optimizer that is zeroed and stepped every batch.
        epoch: Current epoch; printed and used as the logging step.
        meters: Mapping with ``'loss'`` and ``'dice'`` meters exposing
            ``update(value)`` and ``.avg``.
        criterion: Loss called as ``criterion(output, mask)`` on the raw
            model output.
        savepath: Forwarded unchanged to ``save_checkpoint``.
        savefile: Forwarded unchanged to ``save_checkpoint``.
    """
    trainloss = meters['loss']
    traindice = meters['dice']

    model.train()
    for batch_idx, (data, mask) in enumerate(train_loader):
        # Insert the channel axis: (N, ...) -> (N, 1, ...).
        data = data.unsqueeze(1).float()
        mask = mask.unsqueeze(1).float()
        data, mask = data.to(device), mask.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, mask)
        # The metric is not part of the objective: compute it without
        # autograd so ``dice`` (and the meter that accumulates it) cannot
        # keep this batch's computation graph alive.
        with torch.no_grad():
            dice = dice_coefficient(output, mask)
        loss.backward()
        optimizer.step()
        # ``.item()`` synchronises with the device; do it once per batch.
        loss_value = loss.item()
        trainloss.update(loss_value)
        traindice.update(dice)

        if batch_idx % args.log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Dice: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value,
                        traindice.avg))

            info = {'train_loss': loss_value, 'train_dice': traindice.avg}

            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            for tag, value in model.named_parameters():
                tag = tag.replace('.', '/')
                logger.histo_summary(tag, value.detach().cpu().numpy(), epoch)
                # ``grad`` is None for parameters that received no gradient
                # in this step; skip them instead of crashing.
                if value.grad is not None:
                    logger.histo_summary(tag + '/grad',
                                         value.grad.detach().cpu().numpy(),
                                         epoch)

            # First two output maps, using the output's own spatial size
            # instead of a hard-coded 512x512.
            height, width = output.shape[-2:]
            imgs = output.detach().reshape(-1, height, width)[:2].cpu().numpy()
            info = {'segmentations': imgs}

            for tag, images in info.items():
                logger.image_summary(tag, images, epoch)

            if args.checkpoint:
                save_checkpoint(model, optimizer, epoch, loss, savepath,
                                savefile)
Esempio n. 3
0
def train(args, model, start_gpu, end_gpu, train_loader, optimizer, epoch,
          meters, criterion):
    """Train ``model`` for one epoch on downsampled batches.

    For every batch, a channel axis is inserted at dim 1, and the tensors
    are passed through ``downsample_img`` / ``downsample_mask`` (defined
    elsewhere). ``data`` is moved to ``start_gpu`` and ``mask`` to
    ``end_gpu``. The model output is passed through ``torch.sigmoid``
    before ``criterion``, so the loss receives probabilities, not logits.
    Dice is computed on the output thresholded at 0.5. Loss and Dice are
    fed to ``meters['loss']`` / ``meters['dice']``.

    Every ``args.log_interval`` batches the function prints progress and
    sends the following to the module-level ``logger`` at step ``epoch``:
    scalars, parameter/gradient histograms, and up to two probability maps.

    Args:
        args: Options object; reads ``log_interval``.
        model: Network being trained; presumably split across devices so
            its output lands on ``end_gpu`` (TODO confirm against model).
        start_gpu: Device the input batch is moved to.
        end_gpu: Device the mask is moved to.
        train_loader: Iterable of ``(data, mask)`` batches; its ``len`` and
            ``.dataset`` are used for the progress message.
        optimizer: Optimizer that is zeroed and stepped every batch.
        epoch: Current epoch; printed and used as the logging step.
        meters: Mapping with ``'loss'`` and ``'dice'`` meters exposing
            ``update(value)`` and ``.avg``.
        criterion: Loss called as ``criterion(probabilities, mask)``.
    """
    trainloss = meters['loss']
    traindice = meters['dice']

    model.train()
    for batch_idx, (data, mask) in enumerate(train_loader):
        # Insert the channel axis: (N, ...) -> (N, 1, ...).
        data = data.unsqueeze(1).float()
        mask = mask.unsqueeze(1).float()
        data = downsample_img(data)
        mask = downsample_mask(mask)
        data, mask = data.to(start_gpu), mask.to(end_gpu)
        optimizer.zero_grad()
        output = model(data)
        output = torch.sigmoid(output)
        loss = criterion(output, mask)
        # Dice on the hard 0.5-thresholded prediction; no gradient needed.
        with torch.no_grad():
            output_binary = output > 0.5
            dice = dice_coefficient(output_binary.float(), mask)
        loss.backward()
        optimizer.step()
        # ``.item()`` synchronises with the device; do it once per batch.
        loss_value = loss.item()
        trainloss.update(loss_value)
        traindice.update(dice)

        if batch_idx % args.log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Dice: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value,
                        traindice.avg))

            info = {'train_loss': loss_value, 'train_dice': traindice.avg}

            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            for tag, value in model.named_parameters():
                tag = tag.replace('.', '/')
                logger.histo_summary(tag, value.detach().cpu().numpy(), epoch)
                # ``grad`` is None for parameters that received no gradient
                # in this step; skip them instead of crashing.
                if value.grad is not None:
                    logger.histo_summary(tag + '/grad',
                                         value.grad.detach().cpu().numpy(),
                                         epoch)

            # First two probability maps, using the output's own spatial
            # size instead of a hard-coded 256x256.
            height, width = output.shape[-2:]
            imgs = output.detach().reshape(-1, height, width)[:2].cpu().numpy()
            info = {'segmentations': imgs}

            for tag, images in info.items():
                logger.image_summary(tag, images, epoch)