示例#1
0
def train(model, train_loader, optimizer, loss_func, n_labels, alpha,
          epoch=None):
    """Run one training epoch for a model with four output heads.

    ``output[3]`` is the main prediction. ``output[0..2]`` are extra heads
    (presumably deep-supervision outputs); their summed loss is scaled by
    ``alpha`` and added to the main loss before back-propagation.

    Args:
        model: network to train; ``model(data)`` must be indexable 0..3.
        train_loader: iterable of ``(data, target)`` batches; ``target``
            holds integer labels that are one-hot encoded here.
        optimizer: torch optimizer over the model's parameters.
        loss_func: callable ``loss_func(prediction, one_hot_target)``
            returning a scalar tensor.
        n_labels: number of classes passed to the one-hot encoder and the
            Dice meter.
        alpha: weight of the auxiliary-head losses.
        epoch: epoch number for the progress banner. If omitted, a
            module-level ``epoch`` global is used when present (the original
            behaviour); otherwise ``None`` is printed.

    Returns:
        OrderedDict with ``'Train_Loss'``, ``'Train_dice_liver'`` and, when
        ``n_labels == 3``, ``'Train_dice_tumor'``. Note that ``'Train_Loss'``
        averages only the main-head loss (``loss3``), not the combined loss
        that is optimised.
    """
    if epoch is None:
        # The original read a global ``epoch`` it was never given, raising
        # NameError whenever no such global existed.
        epoch = globals().get('epoch')
    # Same value as state_dict()['param_groups'][0]['lr'] without copying
    # the whole optimizer state.
    print("=======Epoch:{}=======lr:{}".format(
        epoch, optimizer.param_groups[0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for data, target in tqdm(train_loader, total=len(train_loader)):
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target, n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss0 = loss_func(output[0], target)
        loss1 = loss_func(output[1], target)
        loss2 = loss_func(output[2], target)
        loss3 = loss_func(output[3], target)

        # Main-head loss plus alpha-weighted auxiliary losses.
        loss = loss3 + alpha * (loss0 + loss1 + loss2)
        loss.backward()
        optimizer.step()

        # Metrics are tracked on the main head only.
        train_loss.update(loss3.item(), data.size(0))
        train_dice.update(output[3], target)

    train_log = OrderedDict({
        'Train_Loss': train_loss.avg,
        'Train_dice_liver': train_dice.avg[1]
    })
    if n_labels == 3:
        train_log.update({'Train_dice_tumor': train_dice.avg[2]})
    return train_log
示例#2
0
def predict_one_img(model, img_dataset, args):
    """Predict a segmentation for one image and score it against its label.

    ``img_dataset`` yields input chunks one at a time. Each model output is
    handed back through ``img_dataset.update_result`` and then reassembled
    with ``recompone_result`` (presumably patch-wise inference followed by
    recomposition; confirm against the dataset class).

    Args:
        model: segmentation network, switched to eval mode here.
        img_dataset: dataset exposing ``label``, ``update_result`` and
            ``recompone_result``.
        args: namespace providing ``n_labels`` and ``postprocess``.

    Returns:
        ``(test_dice, pred)``, where ``test_dice`` is an OrderedDict with
        ``'Dice_liver'`` and, when ``args.n_labels == 3``, ``'Dice_tumor'``,
        and ``pred`` is a SimpleITK image of the uint8 argmax prediction with
        the leading batch axis squeezed out.
    """
    dataloader = DataLoader(dataset=img_dataset,
                            batch_size=1,
                            num_workers=0,
                            shuffle=False)
    model.eval()
    # Keep the meter in its own variable. The original rebound ``test_dice``
    # to the result dict and then read ``.avg`` from the dict, which raised
    # AttributeError for n_labels == 3.
    dice_meter = DiceAverage(args.n_labels)
    target = to_one_hot_3d(img_dataset.label, args.n_labels)

    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            output = model(data)
            # NOTE: the output is not resampled back to the original spatial
            # resolution here. An earlier version used
            # nn.functional.interpolate(..., mode='trilinear') for that.
            img_dataset.update_result(output.detach().cpu())

    pred = img_dataset.recompone_result()
    pred = torch.argmax(pred, dim=1)

    pred_img = common.to_one_hot_3d(pred, args.n_labels)
    dice_meter.update(pred_img, target)

    test_dice = OrderedDict({'Dice_liver': dice_meter.avg[1]})
    if args.n_labels == 3:
        test_dice.update({'Dice_tumor': dice_meter.avg[2]})

    pred = np.asarray(pred.numpy(), dtype='uint8')
    if args.postprocess:
        pass  # TODO: post-processing is not implemented yet.
    pred = sitk.GetImageFromArray(np.squeeze(pred, axis=0))

    return test_dice, pred
示例#3
0
def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` with the mean-Dice loss.

    Gradients are disabled. Targets are one-hot encoded with the helper's
    default class count.

    Returns:
        OrderedDict with the per-batch average loss (``'Val Loss'``) and the
        per-batch average Dice for classes 0, 1 and 2.
    """
    model.eval()
    # Running sums in the order: loss, dice0, dice1, dice2.
    totals = [0, 0, 0, 0]
    with torch.no_grad():
        for _, (inputs, labels) in tqdm(enumerate(val_loader),
                                        total=len(val_loader)):
            labels = common.to_one_hot_3d(labels.long())
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            prediction = model(inputs)

            batch_loss = metrics.DiceMeanLoss()(prediction, labels)
            batch_dice = [metrics.dice(prediction, labels, cls)
                          for cls in range(3)]

            totals[0] += float(batch_loss)
            for slot, score in enumerate(batch_dice, start=1):
                totals[slot] += float(score)

    n_batches = len(val_loader)
    val_loss, val_dice0, val_dice1, val_dice2 = (t / n_batches
                                                 for t in totals)

    return OrderedDict([
        ('Val Loss', val_loss),
        ('Val dice0', val_dice0),
        ('Val dice1', val_dice1),
        ('Val dice2', val_dice2),
    ])
示例#4
0
def val(model, val_loader, criterion, n_labels):
    """Evaluate ``model`` on ``val_loader`` using ``criterion``.

    Returns:
        OrderedDict with ``'Val Loss'``, ``'Val dice0'`` and ``'Val dice1'``;
        ``'Val dice2'`` is added whenever ``n_labels`` is not 2.
    """
    model.eval()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)
    with torch.no_grad():
        for _, batch in tqdm(enumerate(val_loader), total=len(val_loader)):
            inputs, labels = batch
            inputs, labels = inputs.float(), labels.long()
            labels = common.to_one_hot_3d(labels, n_labels)
            inputs, labels = inputs.to(device), labels.to(device)
            prediction = model(inputs)
            batch_loss = criterion(prediction, labels)
            loss_meter.update(batch_loss.item(), inputs.size(0))
            dice_meter.update(prediction, labels)

    summary = OrderedDict([
        ('Val Loss', loss_meter.avg),
        ('Val dice0', dice_meter.avg[0]),
        ('Val dice1', dice_meter.avg[1]),
    ])
    if n_labels != 2:
        summary['Val dice2'] = dice_meter.avg[2]
    return summary
示例#5
0
def train(model, train_loader, optimizer, epoch, logger):
    """Run one training epoch with the mean-Dice loss, then log the results.

    Targets are one-hot encoded with the helper's default class count. The
    epoch averages of loss and per-class Dice (classes 0, 1, 2) are printed
    and written through ``logger.scalar_summary``.

    Args:
        model: network to train.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: torch optimizer over the model's parameters.
        epoch: epoch number used in the banner and as the logging step.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()
    train_loss = 0.0
    train_dice0 = 0.0
    train_dice1 = 0.0
    train_dice2 = 0.0
    for data, target in tqdm(train_loader, total=len(train_loader)):
        target = common.to_one_hot_3d(target.long())
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)
        optimizer.zero_grad()

        # Other losses tried previously: nn.CrossEntropyLoss, nn.MSELoss,
        # metrics.SoftDiceLoss, metrics.WeightDiceLoss, metrics.CrossEntropy.
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate Python floats. Summing the tensors directly (as before)
        # chains every batch's autograd graph into the running total, so all
        # of them stay in memory until the end of the epoch.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))

    n_batches = len(train_loader)
    train_loss /= n_batches
    train_dice0 /= n_batches
    train_dice1 /= n_batches
    train_dice2 /= n_batches

    print(
        'Train Epoch: {} \tLoss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}'
        .format(epoch, train_loss, train_dice0, train_dice1, train_dice2))

    logger.scalar_summary('train_loss', float(train_loss), epoch)
    logger.scalar_summary('train_dice0', float(train_dice0), epoch)
    logger.scalar_summary('train_dice1', float(train_dice1), epoch)
    logger.scalar_summary('train_dice2', float(train_dice2), epoch)
示例#6
0
def val(model, val_loader, loss_func, n_labels):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        OrderedDict with ``'Val_Loss'``, ``'Val_dice_liver'`` (Dice of class
        index 1) and, when ``n_labels == 3``, ``'Val_dice_tumor'`` (class
        index 2).
    """
    model.eval()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)
    with torch.no_grad():
        for _, (batch, labels) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            batch = batch.float()
            labels = common.to_one_hot_3d(labels.long(), n_labels)
            batch, labels = batch.to(device), labels.to(device)
            prediction = model(batch)
            batch_loss = loss_func(prediction, labels)

            loss_meter.update(batch_loss.item(), batch.size(0))
            dice_meter.update(prediction, labels)

    log = OrderedDict()
    log['Val_Loss'] = loss_meter.avg
    log['Val_dice_liver'] = dice_meter.avg[1]
    if n_labels == 3:
        log['Val_dice_tumor'] = dice_meter.avg[2]
    return log
示例#7
0
def train(model, train_loader, optimizer, epoch=None):
    """Run one training epoch with the weighted Dice loss.

    Targets are one-hot encoded with the helper's default class count.

    Args:
        model: network to train.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: torch optimizer over the model's parameters.
        epoch: epoch number for the progress banner. If omitted, a
            module-level ``epoch`` global is used when present (the original
            behaviour); otherwise ``None`` is printed.

    Returns:
        OrderedDict with the per-batch average ``'Train Loss'`` and Dice for
        classes 0, 1 and 2.
    """
    if epoch is None:
        # The original read a global ``epoch`` it was never given, raising
        # NameError whenever no such global existed.
        epoch = globals().get('epoch')
    print("=======Epoch:{}=======".format(epoch))
    model.train()
    train_loss = 0
    train_dice0 = 0
    train_dice1 = 0
    train_dice2 = 0
    for data, target in tqdm(train_loader, total=len(train_loader)):
        target = common.to_one_hot_3d(target.long())
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)
        optimizer.zero_grad()

        # Other losses tried previously: nn.CrossEntropyLoss, nn.MSELoss,
        # metrics.SoftDiceLoss, metrics.DiceMeanLoss, metrics.CrossEntropy.
        loss = metrics.WeightDiceLoss()(output, target)
        loss.backward()
        optimizer.step()

        # float() detaches the values so no autograd graph is retained.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))

    n_batches = len(train_loader)
    train_loss /= n_batches
    train_dice0 /= n_batches
    train_dice1 /= n_batches
    train_dice2 /= n_batches

    return OrderedDict({
        'Train Loss': train_loss,
        'Train dice0': train_dice0,
        'Train dice1': train_dice1,
        'Train dice2': train_dice2
    })
示例#8
0
def train(model, train_loader, optimizer, criterion, n_labels, epoch=None):
    """Run one training epoch using ``criterion`` as the loss.

    Args:
        model: network to train.
        train_loader: iterable of ``(data, target)`` batches; ``target``
            holds integer labels that are one-hot encoded here.
        optimizer: torch optimizer over the model's parameters.
        criterion: callable ``criterion(prediction, one_hot_target)``
            returning a scalar tensor.
        n_labels: number of classes for the one-hot encoding and Dice meter.
        epoch: epoch number for the progress banner. If omitted, a
            module-level ``epoch`` global is used when present (the original
            behaviour); otherwise ``None`` is printed.

    Returns:
        OrderedDict with ``'Train Loss'``, ``'Train dice0'`` and
        ``'Train dice1'``; ``'Train dice2'`` is added whenever ``n_labels`` is
        not 2.
    """
    if epoch is None:
        # The original read a global ``epoch`` it was never given, raising
        # NameError whenever no such global existed.
        epoch = globals().get('epoch')
    # Same value as state_dict()['param_groups'][0]['lr'] without copying
    # the whole optimizer state.
    print("=======Epoch:{}=======lr:{}".format(
        epoch, optimizer.param_groups[0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for data, target in tqdm(train_loader, total=len(train_loader)):
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target, n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), data.size(0))
        train_dice.update(output, target)

    train_log = OrderedDict({
        'Train Loss': train_loss.avg,
        'Train dice0': train_dice.avg[0],
        'Train dice1': train_dice.avg[1]
    })
    if n_labels != 2:
        train_log['Train dice2'] = train_dice.avg[2]
    return train_log
示例#9
0
def val(model, val_loader, epoch, logger):
    """Evaluate ``model`` with the mean-Dice loss, then log and print results.

    Per-batch averages of the loss and of the Dice for classes 0, 1 and 2
    are written via ``logger.scalar_summary`` at step ``epoch`` and then
    printed. Nothing is returned.
    """
    model.eval()
    # Running sums in the order: loss, dice0, dice1, dice2.
    sums = [0, 0, 0, 0]
    with torch.no_grad():
        for _, (x, y) in tqdm(enumerate(val_loader), total=len(val_loader)):
            y = common.to_one_hot_3d(y.long())
            x, y = x.float().to(device), y.float().to(device)
            pred = model(x)

            scores = [metrics.DiceMeanLoss()(pred, y)]
            scores += [metrics.dice(pred, y, cls) for cls in (0, 1, 2)]
            for i, score in enumerate(scores):
                sums[i] += float(score)

    count = len(val_loader)
    val_loss, val_dice0, val_dice1, val_dice2 = [s / count for s in sums]

    tags = ('val_loss', 'val_dice0', 'val_dice1', 'val_dice2')
    values = (val_loss, val_dice0, val_dice1, val_dice2)
    for tag, value in zip(tags, values):
        logger.scalar_summary(tag, value, epoch)
    print(
        'Val performance: Average loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(val_loss, val_dice0, val_dice1, val_dice2))