コード例 #1
0
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the averaged metrics.

    For every batch the mean Dice loss (``metrics.DiceMeanLoss``) and the
    Dice scores for classes 0, 1 and 2 (``metrics.dice``) are computed with
    gradients disabled.  Each value is summed as a Python float and the sums
    are divided by ``len(test_loader)``, i.e. the result is a mean over
    batches.  Nothing is returned; the averages are only printed.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(data, target)`` batches supporting
            ``len()``; both tensors are cast to float and moved to ``device``.

    Raises:
        ZeroDivisionError: if ``test_loader`` has length 0.
    """
    print("Evaluation of Testset Starting...")
    model.eval()
    # Running sums, in order: [loss, dice0, dice1, dice2].
    totals = [0, 0, 0, 0]
    with torch.no_grad():
        for batch_data, batch_target in tqdm(test_loader):
            batch_data = batch_data.float().to(device)
            batch_target = batch_target.float().to(device)
            prediction = model(batch_data)

            batch_scores = (
                metrics.DiceMeanLoss()(prediction, batch_target),
                metrics.dice(prediction, batch_target, 0),
                metrics.dice(prediction, batch_target, 1),
                metrics.dice(prediction, batch_target, 2),
            )
            totals = [acc + float(score)
                      for acc, score in zip(totals, batch_scores)]

    num_batches = len(test_loader)
    avg_loss, avg_dice0, avg_dice1, avg_dice2 = [t / num_batches for t in totals]

    print('\nTest set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        avg_loss, avg_dice0, avg_dice1, avg_dice2))
コード例 #2
0
def val(model, val_loader):
    """Validate ``model`` and return the batch-averaged loss and Dice scores.

    Each target batch is cast to ``long`` and one-hot encoded with
    ``common.to_one_hot_3d``.  Inputs and targets are then cast to float and
    moved to ``device``.  With gradients disabled, the mean Dice loss and the
    Dice scores for classes 0, 1 and 2 are computed per batch, summed as
    Python floats and divided by ``len(val_loader)``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(data, target)`` batches supporting ``len()``.

    Returns:
        OrderedDict with keys 'Val Loss', 'Val dice0', 'Val dice1' and
        'Val dice2', in that order, each a mean over batches.

    Raises:
        ZeroDivisionError: if ``val_loader`` has length 0.
    """
    model.eval()
    loss_sum = dice0_sum = dice1_sum = dice2_sum = 0
    with torch.no_grad():
        for _, (inputs, labels) in tqdm(enumerate(val_loader),
                                        total=len(val_loader)):
            labels = common.to_one_hot_3d(labels.long()).float().to(device)
            inputs = inputs.float().to(device)
            logits = model(inputs)

            batch_loss = metrics.DiceMeanLoss()(logits, labels)
            batch_dice0 = metrics.dice(logits, labels, 0)
            batch_dice1 = metrics.dice(logits, labels, 1)
            batch_dice2 = metrics.dice(logits, labels, 2)

            loss_sum += float(batch_loss)
            dice0_sum += float(batch_dice0)
            dice1_sum += float(batch_dice1)
            dice2_sum += float(batch_dice2)

    count = len(val_loader)
    mean_loss = loss_sum / count
    mean_dice0 = dice0_sum / count
    mean_dice1 = dice1_sum / count
    mean_dice2 = dice2_sum / count

    return OrderedDict([
        ('Val Loss', mean_loss),
        ('Val dice0', mean_dice0),
        ('Val dice1', mean_dice1),
        ('Val dice2', mean_dice2),
    ])
コード例 #3
0
ファイル: test.py プロジェクト: sainatarajan/3DUNet-Pytorch
def test(model, dataset, save_path, filename):
    """Predict a whole volume patch-by-patch, save it, and report Dice metrics.

    Items of ``dataset`` are fed to ``model`` in order, in batches of 4
    (single-process loading).  Each output batch is detached, moved to the CPU
    and passed to a ``Recompone_tool``.  The tool's ``recompone_overlap()``
    result is taken as the full prediction (presumably it stitches
    overlapping patches back together; confirm against ``Recompone_tool``).

    The prediction is then:
    - scored against ``dataset.label_np``, one-hot encoded via ``to_one_hot_3d``;
    - reduced with ``argmax`` over dim 1 and written as a uint8 image to
      ``os.path.join(save_path, filename)`` with SimpleITK;
    - reported by printing the scores.

    Args:
        model: segmentation network; it is switched to eval mode.
        dataset: source of input patches; must also expose ``ori_shape``,
            ``new_shape``, ``cut`` and ``label_np``.
        save_path: output directory (also passed to ``Recompone_tool``).
        filename: output file name (also passed to ``Recompone_tool``).

    Returns:
        ``(loss, dice0, dice1, dice2)`` exactly as produced by
        ``metrics.DiceMeanLoss`` and ``metrics.dice``.
    """
    loader = DataLoader(dataset=dataset, batch_size=4, num_workers=0,
                        shuffle=False)
    model.eval()
    recomposer = Recompone_tool(save_path, filename, dataset.ori_shape,
                                dataset.new_shape, dataset.cut)
    # Prepend a batch axis to the label volume before one-hot encoding it.
    label_tensor = torch.from_numpy(
        np.expand_dims(dataset.label_np, axis=0)).long()
    target = to_one_hot_3d(label_tensor)

    with torch.no_grad():
        for patches in tqdm(loader, total=len(loader)):
            # Insert a singleton axis at dim 1 (presumably the channel axis).
            patches = patches.unsqueeze(1).float().to(device)
            recomposer.add_result(model(patches).detach().cpu())

    pred = torch.unsqueeze(recomposer.recompone_overlap(), dim=0)
    val_loss = metrics.DiceMeanLoss()(pred, target)
    val_dice0, val_dice1, val_dice2 = [
        metrics.dice(pred, target, cls) for cls in (0, 1, 2)
    ]

    # Argmax over dim 1 (presumably the class axis), then drop the batch axis.
    label_map = torch.argmax(pred, dim=1).numpy()
    volume = np.squeeze(np.array(label_map, dtype='uint8'), axis=0)
    sitk.WriteImage(sitk.GetImageFromArray(volume),
                    os.path.join(save_path, filename))

    print(
        '\nAverage loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(val_loss, val_dice0, val_dice1, val_dice2))
    return val_loss, val_dice0, val_dice1, val_dice2
コード例 #4
0
def val(model, val_loader, epoch, logger):
    """Validate ``model``, log batch-averaged metrics for ``epoch`` and print them.

    Inputs and targets are cast to float and moved to ``device``.  With
    gradients disabled, each batch contributes its mean Dice loss and its Dice
    scores for classes 0, 1 and 2.  These are summed as Python floats and
    divided by ``len(val_loader)``.  The averages are sent to
    ``logger.scalar_summary`` under the tags 'val_loss', 'val_dice0',
    'val_dice1' and 'val_dice2', then printed.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        epoch: step value passed to the logger.
        logger: object exposing ``scalar_summary(tag, value, step)``.

    Raises:
        ZeroDivisionError: if ``val_loader`` has length 0 (nothing is logged).
    """
    model.eval()
    tags = ('val_loss', 'val_dice0', 'val_dice1', 'val_dice2')
    sums = dict.fromkeys(tags, 0)
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.float().to(device), y.float().to(device)
            y_hat = model(x)

            scores = [metrics.DiceMeanLoss()(y_hat, y)]
            scores += [metrics.dice(y_hat, y, cls) for cls in (0, 1, 2)]
            for tag, score in zip(tags, scores):
                sums[tag] += float(score)

    n_batches = len(val_loader)
    means = {tag: total / n_batches for tag, total in sums.items()}

    for tag in tags:
        logger.scalar_summary(tag, means[tag], epoch)
    print('\nVal set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        *(means[tag] for tag in tags)))
コード例 #5
0
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch and report batch-averaged loss and Dice.

    Each batch has its leading dimension removed with
    ``torch.squeeze(..., dim=0)``, is cast to float and moved to ``device``.
    The model is then optimised against ``metrics.DiceMeanLoss``.  The
    per-batch loss and the Dice scores for classes 0, 1 and 2 are accumulated
    as plain Python floats and divided by ``len(train_loader)``.  The averages
    are printed and written to ``logger.scalar_summary``.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: epoch index used in the banner, the summary line and the logger.
        logger: object exposing ``scalar_summary(tag, value, step)``.

    Raises:
        ZeroDivisionError: if ``train_loader`` has length 0.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()
    train_loss = 0.0
    train_dice0 = 0.0
    train_dice1 = 0.0
    train_dice2 = 0.0
    for _, (data, target) in tqdm(enumerate(train_loader),
                                  total=len(train_loader)):
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()

        # Alternatives previously tried here: nn.CrossEntropyLoss,
        # metrics.SoftDiceLoss, nn.MSELoss, metrics.WeightDiceLoss,
        # metrics.CrossEntropy.
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate plain floats.  Summing the tensors themselves (as before)
        # chains every batch's autograd graph onto the running total and keeps
        # it alive for the whole epoch.
        train_loss += float(loss)
        with torch.no_grad():  # metrics need no gradient tracking
            train_dice0 += float(metrics.dice(output, target, 0))
            train_dice1 += float(metrics.dice(output, target, 1))
            train_dice2 += float(metrics.dice(output, target, 2))

    num_batches = len(train_loader)
    train_loss /= num_batches
    train_dice0 /= num_batches
    train_dice1 /= num_batches
    train_dice2 /= num_batches

    print(
        'Train Epoch: {} \tLoss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}'
        .format(epoch, train_loss, train_dice0, train_dice1, train_dice2))

    logger.scalar_summary('train_loss', float(train_loss), epoch)
    logger.scalar_summary('train_dice0', float(train_dice0), epoch)
    logger.scalar_summary('train_dice1', float(train_dice1), epoch)
    logger.scalar_summary('train_dice2', float(train_dice2), epoch)
コード例 #6
0
def _module_global(name):
    """Return the module-level global ``name``, raising NameError like a bare lookup."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError("name '{}' is not defined".format(name)) from None


def train(model, train_loader, optimizer=None, epoch=None):
    """Train ``model`` for one epoch and return the batch-averaged metrics.

    This function used to read ``optimizer`` and ``epoch`` implicitly from
    module-level globals, so it failed when called from any other module.
    Both can now be passed explicitly.  When omitted (``None``), they are
    still looked up in this module's globals, so existing
    ``train(model, train_loader)`` callers behave as before.

    Each batch has its leading dimension removed with
    ``torch.squeeze(..., dim=0)``, is cast to float and moved to ``device``.
    The model is then optimised against ``metrics.DiceMeanLoss``.  The loss
    and the Dice scores for classes 0, 1 and 2 are accumulated as Python
    floats and divided by ``len(train_loader)``.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
            Defaults to the module-level ``optimizer``.
        epoch: epoch number printed in the banner.  Defaults to the
            module-level ``epoch``.

    Returns:
        OrderedDict with keys 'Train Loss', 'Train dice0', 'Train dice1' and
        'Train dice2', in that order, each a mean over batches.

    Raises:
        NameError: if ``optimizer``/``epoch`` is omitted and no module global
            of that name exists.
        ZeroDivisionError: if ``train_loader`` has length 0.
    """
    if epoch is None:
        epoch = _module_global('epoch')
    if optimizer is None:
        optimizer = _module_global('optimizer')

    print("=======Epoch:{}=======".format(epoch))
    model.train()
    train_loss = 0
    train_dice0 = 0
    train_dice1 = 0
    train_dice2 = 0
    for _, (data, target) in tqdm(enumerate(train_loader),
                                  total=len(train_loader)):
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()

        # Alternatives previously tried here: nn.CrossEntropyLoss,
        # metrics.SoftDiceLoss, nn.MSELoss, metrics.WeightDiceLoss,
        # metrics.CrossEntropy.
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # float() detaches the values so no autograd graph outlives the batch.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))

    num_batches = len(train_loader)
    train_loss /= num_batches
    train_dice0 /= num_batches
    train_dice1 /= num_batches
    train_dice2 /= num_batches

    return OrderedDict({
        'Train Loss': train_loss,
        'Train dice0': train_dice0,
        'Train dice1': train_dice1,
        'Train dice2': train_dice2
    })
コード例 #7
0
def train(model, train_loader, optimizer, epoch, logger):
    """Train for one epoch, printing per-batch progress and logging epoch means.

    Each batch has its leading dimension removed with
    ``torch.squeeze(..., dim=0)``, is cast to float and moved to ``device``.
    The model is then optimised against ``metrics.DiceMeanLoss``.  After every
    step one line is printed with:
    - the batch loss;
    - the Dice scores for classes 0, 1 and 2;
    - ``metrics.T``/``P``/``TP`` for that batch.

    The ``train_*`` totals are initialised as accumulators but used to be
    overwritten on every batch.  As a result the logger only ever received
    the last batch's values, and that batch's autograd graph was kept alive.
    They are now summed as plain floats and the epoch mean is logged.  This
    matches the other ``train`` variants.  An empty loader still logs 0.0 for
    every tag, as before.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: epoch index used in the progress lines and as the logger step.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    model.train()
    num_batches = len(train_loader)
    train_loss = 0.0
    train_dice0 = 0.0
    train_dice1 = 0.0
    train_dice2 = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()

        # Alternatives previously tried here: nn.CrossEntropyLoss,
        # metrics.SoftDiceLoss, nn.MSELoss, metrics.WeightDiceLoss,
        # metrics.CrossEntropy.
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Reporting metrics need no gradients; floats avoid retaining graphs.
        with torch.no_grad():
            batch_dice0 = float(metrics.dice(output, target, 0))
            batch_dice1 = float(metrics.dice(output, target, 1))
            batch_dice2 = float(metrics.dice(output, target, 2))
            batch_t = metrics.T(output, target)
            batch_p = metrics.P(output, target)
            batch_tp = metrics.TP(output, target)
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tdice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tT: {:.6f}\tP: {:.6f}\tTP: {:.6f}'
            .format(epoch, batch_idx,
                    num_batches, 100. * batch_idx / num_batches,
                    batch_loss, batch_dice0, batch_dice1, batch_dice2,
                    batch_t, batch_p, batch_tp))

        train_loss += batch_loss
        train_dice0 += batch_dice0
        train_dice1 += batch_dice1
        train_dice2 += batch_dice2

    # Mean over batches; with no batches the 0.0 sums are logged unchanged.
    divisor = max(num_batches, 1)
    logger.scalar_summary('train_loss', train_loss / divisor, epoch)
    logger.scalar_summary('train_dice0', train_dice0 / divisor, epoch)
    logger.scalar_summary('train_dice1', train_dice1 / divisor, epoch)
    logger.scalar_summary('train_dice2', train_dice2 / divisor, epoch)
コード例 #8
0
def val(model, val_loader, epoch, logger):
    """Validate ``model``, log batch-averaged metrics for ``epoch`` and print them.

    Each batch has its leading dimension removed with
    ``torch.squeeze(..., dim=0)``, is cast to float and moved to ``device``.
    With gradients disabled, the mean Dice loss and the Dice scores for
    classes 0, 1 and 2 are summed as Python floats and divided by
    ``len(val_loader)``.  The averages are sent to ``logger.scalar_summary``
    under 'val_loss', 'val_dice0', 'val_dice1' and 'val_dice2', then printed.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        epoch: step value passed to the logger.
        logger: object exposing ``scalar_summary(tag, value, step)``.

    Raises:
        ZeroDivisionError: if ``val_loader`` has length 0 (nothing is logged).
    """
    model.eval()
    loss_total = 0
    dice_totals = [0, 0, 0]  # per-class Dice sums for classes 0, 1, 2
    with torch.no_grad():
        for _, (volume, label) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            volume = torch.squeeze(volume, dim=0).float().to(device)
            label = torch.squeeze(label, dim=0).float().to(device)
            out = model(volume)

            batch_loss = metrics.DiceMeanLoss()(out, label)
            batch_dice = [metrics.dice(out, label, cls) for cls in (0, 1, 2)]

            loss_total += float(batch_loss)
            dice_totals = [acc + float(d)
                           for acc, d in zip(dice_totals, batch_dice)]

    n = len(val_loader)
    mean_loss = loss_total / n
    mean_dice0, mean_dice1, mean_dice2 = [acc / n for acc in dice_totals]

    logger.scalar_summary('val_loss', mean_loss, epoch)
    logger.scalar_summary('val_dice0', mean_dice0, epoch)
    logger.scalar_summary('val_dice1', mean_dice1, epoch)
    logger.scalar_summary('val_dice2', mean_dice2, epoch)
    print(
        'Val performance: Average loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(mean_loss, mean_dice0, mean_dice1, mean_dice2))