Example #1
0
def train_epoch(model,
                loader,
                optimizer,
                epoch,
                n_epochs,
                print_freq=1,
                writer=None):
    """Train ``model`` for one epoch and plot the last batch's prediction.

    Args:
        model: network called as ``model(x)``; put in train mode and updated.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index (shown as ``epoch + 1`` and used in the
            plot's output path).
        n_epochs: total epoch count, used only in the progress line.
        print_freq: kept for interface compatibility; it is overwritten every
            iteration by an adaptive value computed from the batch time.
        writer: kept for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, i.e. the running
        averages of: loss, mean of ``iou[1:]``, mean of ``dice[1:]``, every
        per-class IoU, every per-class Dice.
    """
    meters = MultiAverageMeter()
    model.train()

    pred_logit = None
    y_one_hot = None
    end = time.time()
    for batch_idx, (x, y) in enumerate(loader):
        x = to_var(x)
        y = to_var(y)

        # forward / backward
        pred_logit = model(x)
        y_one_hot = categorical_to_one_hot(y, dim=1, expand_dim=False)

        loss = soft_dice_loss(pred_logit, y_one_hot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics. Every logged entry goes through .item() so the meter stores
        # plain Python floats, consistent with the per-class entries.
        iou = cal_iou(pred_logit, y_one_hot)
        dice = cal_dice(pred_logit, y_one_hot)

        logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                + [iou[i].item() for i in range(len(iou))]
                + [dice[i].item() for i in range(len(dice))]
                + [time.time() - end])
        meters.update(logs, y.size(0))

        end = time.time()

        # Print roughly once per 2 s worth of batches; max() avoids a
        # ZeroDivisionError if the measured batch time is 0.0.
        print_freq = 2 // max(meters.val[-1], 1e-8) + 1
        if batch_idx % print_freq == 0:
            res = '\t'.join([
                'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
                'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
            ])
            print(res)

    if pred_logit is None:
        # Empty loader: there is no batch to plot.
        return meters.avg[:-1]

    # Plot the first sample of the last batch against its label.
    pred_one_hot = categorical_to_one_hot(pred_logit.argmax(dim=1),
                                          dim=1,
                                          expand_dim=True)
    plot_multi_voxels(pred_one_hot[0], y_one_hot[0])
    plot_dir = os.path.join(cfg.save, 'epoch_{}'.format(epoch))
    # savefig does not create missing directories.
    os.makedirs(plot_dir, exist_ok=True)
    plt.savefig(os.path.join(plot_dir, 'train_{}.pdf'.format(epoch)))
    plt.close()
    return meters.avg[:-1]
def train_epoch(model,
                loader,
                optimizer,
                epoch,
                n_epochs,
                print_freq=1,
                writer=None):
    """Train ``model`` for one epoch, logging the loss of every iteration.

    Each iteration's loss is sent to ``writer`` (when one is given) under the
    tag ``'train_loss_logs'`` and appended to ``<cfg.save>/loss_logs.csv``.
    The module-level ``iteration`` counter is incremented once per batch.

    Args:
        model: network called as ``model(x)``; put in train mode and updated.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index (shown as ``epoch + 1``).
        n_epochs: total epoch count, used only in the progress line.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        writer: optional object with ``add_scalar(tag, value, step)``; when
            ``None`` the scalar logging is skipped.

    Returns:
        ``meters.avg`` without its last (batch-time) entry.
    """
    global iteration
    meters = MultiAverageMeter()
    model.train()
    end = time.time()
    for batch_idx, (x, y) in enumerate(loader):
        x = to_var(x)
        y = to_var(y)

        # forward / backward
        pred_logit = model(x)
        loss = soft_dice_loss(pred_logit, y, smooth=1e-2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The metric helpers below are given integer labels.
        y = y.long()

        batch_size = y.size(0)
        iou = cal_batch_iou(pred_logit, y)
        dice = cal_batch_dice(pred_logit, y)

        loss_value = loss.item()
        logs = ([loss_value, iou[1:].mean().item(), dice[1:].mean().item()]
                + [iou[i].item() for i in range(len(iou))]
                + [dice[i].item() for i in range(len(dice))]
                + [time.time() - end])
        meters.update(logs, batch_size)

        # `writer` defaults to None, so only log scalars when one was given.
        if writer is not None:
            writer.add_scalar('train_loss_logs', loss_value, iteration)
        with open(os.path.join(cfg.save, 'loss_logs.csv'), 'a') as f:
            f.write('%09d,%0.6f,\n' % (iteration + 1, loss_value))
        iteration += 1

        end = time.time()

        # Print roughly once per 2 s worth of batches; max() avoids a
        # ZeroDivisionError if the measured batch time is 0.0.
        print_freq = 2 // max(meters.val[-1], 1e-8) + 1
        if batch_idx % print_freq == 0:
            res = '\t'.join([
                'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
                'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
            ])
            print(res)

    return meters.avg[:-1]
def train_epoch(model, loader, optimizer, epoch, results_logger):
    '''
    One training epoch.

    Trains ``model`` on every batch of ``loader``, logs each iteration's loss
    to the module-level ``writer`` and to ``<cfg.save>/loss_logs.csv``, and
    increments the module-level ``iteration`` counter.

    Args:
        model: network called as ``model(x)``; put in train mode and updated.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index (shown as ``epoch + 1`` of
            ``cfg.n_epochs``).
        results_logger: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        epoch-level Dice ``2 * intersection / union`` for label 1
        (``nan`` when ``union`` is 0).
    '''
    global iteration
    meters = AverageMeter()
    model.train()
    intersection = 0
    union = 0
    end = time.time()
    for batch_idx, (x, y) in enumerate(loader):
        x = to_device(x)
        y = to_device(y)
        # forward and backward
        pred_logit = model(x)
        y_one_hot = categorical_to_one_hot(y, dim=1, expand_dim=False)

        loss = soft_dice_loss(pred_logit, y_one_hot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate label-1 overlap counts for the epoch-level Dice.
        pred_classes = pred_logit.argmax(1)
        intersection += ((pred_classes == 1) * (y[:, 0] == 1)).sum().item()
        union += ((pred_classes == 1).sum() + y[:, 0].sum()).item()
        batch_size = y.size(0)

        iou = cal_batch_iou(pred_logit, y_one_hot)
        dice = cal_batch_dice(pred_logit, y_one_hot)
        # log (`writer` and `cfg` are module-level names, not parameters)
        loss_value = loss.item()
        writer.add_scalar('train_loss_logs', loss_value, iteration)
        with open(os.path.join(cfg.save, 'loss_logs.csv'), 'a') as f:
            f.write('%09d,%0.6f,\n' % ((iteration + 1), loss_value))
        iteration += 1

        # The batch time is logged last. The progress line, the adaptive
        # print frequency and the returned `avg[:-1]` all rely on
        # meters.val[-1] / avg[-1] being that time. Without it they picked
        # up the last class's Dice, and a Dice of 0 crashed on `2 // 0.0`.
        logs = ([loss_value, iou[1:].mean().item(), dice[1:].mean().item()]
                + [iou[i].item() for i in range(len(iou))]
                + [dice[i].item() for i in range(len(dice))]
                + [time.time() - end])
        meters.update(logs, batch_size)
        end = time.time()

        # print stats roughly once per 2 s worth of batches
        print_freq = 2 // max(meters.val[-1], 1e-8) + 1
        if batch_idx % print_freq == 0:
            res = '\t'.join([
                'Epoch: [%d/%d]' % (epoch + 1, cfg.n_epochs),
                'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
            ])
            print(res)
    # Dice is undefined when neither prediction nor label has any label-1 voxel.
    dice_global = 2. * intersection / union if union > 0 else float('nan')
    return meters.avg[:-1] + [dice_global]
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Earlier code collected every batch's labels and softmax volume on the CPU
    and never used them. That held the whole dataset's predictions in memory,
    and the softmax was taken over the last (spatial) axis rather than the
    class axis. The accumulation has been removed.

    Args:
        model: network called as ``model(x)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``'Test'`` vs ``'Valid'`` prefix of the
            progress line.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry.
    """
    meters = MultiAverageMeter()
    model.eval()
    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x)
            y = to_var(y)

            pred_logit = model(x)

            loss = soft_dice_loss(pred_logit, y, smooth=1e-2)
            # The metric helpers below are given integer labels.
            y = y.long()
            batch_size = y.size(0)
            iou = cal_batch_iou(pred_logit, y)
            dice = cal_batch_dice(pred_logit, y)

            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)

            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)

    return meters.avg[:-1]
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    """
    One test epoch.

    Evaluates ``model`` on ``loader`` without gradient tracking. It also
    accumulates label-1 overlap counts across the epoch for a global Dice.

    Args:
        model: network called as ``model(x)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``"Test"`` vs ``"Valid"`` progress prefix.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        epoch-level Dice ``2 * intersection / union`` (``nan`` if ``union``
        is 0).
    """
    meters = MultiAverageMeter()
    model.eval()
    intersection = 0
    union = 0
    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x)
            y = to_var(y)
            # forward
            pred_logit = model(x)
            # calculate metrics
            y_one_hot = categorical_to_one_hot(y, dim=1, expand_dim=False)
            pred_classes = pred_logit.argmax(1)
            intersection += ((pred_classes == 1) * (y[:, 0] == 1)).sum().item()
            union += ((pred_classes == 1).sum() + y[:, 0].sum()).item()

            loss = soft_dice_loss(pred_logit, y_one_hot)
            batch_size = y.size(0)

            iou = cal_batch_iou(pred_logit, y_one_hot)
            dice = cal_batch_dice(pred_logit, y_one_hot)

            # .item() everywhere so the meter holds plain floats only.
            logs = (
                [loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                + [iou[i].item() for i in range(len(iou))]
                + [dice[i].item() for i in range(len(dice))]
                + [time.time() - end]
            )
            meters.update(logs, batch_size)

            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = "\t".join(
                    [
                        "Test" if is_test else "Valid",
                        "Iter: [%d/%d]" % (batch_idx + 1, len(loader)),
                        "Time %.3f (%.3f)" % (meters.val[-1], meters.avg[-1]),
                        "Loss %.4f (%.4f)" % (meters.val[0], meters.avg[0]),
                        "IOU %.4f (%.4f)" % (meters.val[1], meters.avg[1]),
                        "DICE %.4f (%.4f)" % (meters.val[2], meters.avg[2]),
                    ]
                )
                print(res)
    # Dice is undefined when neither prediction nor label has any label-1 voxel.
    dice_global = 2.0 * intersection / union if union > 0 else float("nan")

    return meters.avg[:-1] + [dice_global]
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The unused ``global iteration`` declaration, the never-updated
    ``intersection``/``union`` counters and a commented-out metric call have
    been removed.

    Args:
        model: network called as ``model(x)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``'Test'`` vs ``'Valid'`` progress prefix.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry.
    """
    meters = MultiAverageMeter()
    model.eval()
    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x)
            y = to_var(y)
            # compute output
            pred_logit = model(x)
            loss = soft_dice_loss(pred_logit, y, smooth=1e-2)

            # measure accuracy and record loss
            batch_size = y.size(0)
            iou = cal_batch_iou(pred_logit, y)
            dice = cal_batch_dice(pred_logit, y)

            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)

            # measure elapsed time
            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)
    return meters.avg[:-1]
def test_epoch(model, loader, optimizer, epoch, results_logger):
    '''
    One test epoch.

    Evaluates ``model`` on ``loader`` without gradient tracking and
    accumulates label-1 overlap counts across the epoch for a global Dice.

    Args:
        model: network called as ``model(x)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        optimizer: accepted for interface compatibility; not used.
        epoch: accepted for interface compatibility; not used.
        results_logger: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        epoch-level Dice ``2 * intersection / union`` (``nan`` if ``union``
        is 0).
    '''
    meters = AverageMeter()
    model.eval()
    intersection = 0
    union = 0
    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_device(x)
            y = to_device(y)
            # forward
            pred_logit = model(x)
            # calculate metrics
            y_one_hot = categorical_to_one_hot(y, dim=1, expand_dim=False)
            pred_classes = pred_logit.argmax(1)
            intersection += ((pred_classes == 1) * (y[:, 0] == 1)).sum().item()
            union += ((pred_classes == 1).sum() + y[:, 0].sum()).item()

            loss = soft_dice_loss(pred_logit, y_one_hot)
            batch_size = y.size(0)

            iou = cal_batch_iou(pred_logit, y_one_hot)
            dice = cal_batch_dice(pred_logit, y_one_hot)

            # The batch time must be the last logged entry: the progress line,
            # the adaptive print frequency and the returned `avg[:-1]` all
            # treat meters.val[-1] / avg[-1] as that time. Without it they used
            # the last class's Dice, and `2 // 0.0` crashed when it was 0.
            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)
            end = time.time()

            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)
    # Dice is undefined when neither prediction nor label has any label-1 voxel.
    dice_global = 2. * intersection / union if union > 0 else float('nan')

    return meters.avg[:-1] + [dice_global]
Example #8
0
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network called as ``model(x)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``'Test'`` vs ``'Valid'`` progress prefix.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry.
    """
    meters = MultiAverageMeter()
    model.eval()

    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x)
            y = to_var(y)
            pred_logit = model(x)
            loss = soft_dice_loss(pred_logit, y, smooth=1e-2)
            # The metric helpers below are given integer labels.
            y = y.long()
            batch_size = y.size(0)
            iou = cal_batch_iou(pred_logit, y)
            dice = cal_batch_dice(pred_logit, y)

            # .item() everywhere so the meter holds plain floats only.
            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)

            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)

    return meters.avg[:-1]
def train_epoch(model, loader, optimizer, epoch, n_epochs, print_freq=1, writer=None):
    """
    One training epoch.

    Trains ``model`` on every batch of ``loader``. Each iteration's loss is
    sent to ``writer`` (when one is given) and appended to
    ``<cfg.save>/loss_logs.csv``. The module-level ``iteration`` counter is
    incremented once per batch.

    Args:
        model: network called as ``model(x)``; put in train mode and updated.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index (shown as ``epoch + 1``).
        n_epochs: total epoch count, used only in the progress line.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        writer: optional object with ``add_scalar(tag, value, step)``; when
            ``None`` the scalar logging is skipped.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        epoch-level Dice ``2 * intersection / union`` for label 1
        (``nan`` if ``union`` is 0).
    """
    global iteration
    meters = MultiAverageMeter()
    model.train()
    intersection = 0
    union = 0
    end = time.time()
    for batch_idx, (x, y) in enumerate(loader):
        x = to_var(x)
        y = to_var(y)
        # forward and backward
        pred_logit = model(x)
        y_one_hot = categorical_to_one_hot(y, dim=1, expand_dim=False)

        loss = soft_dice_loss(pred_logit, y_one_hot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate label-1 overlap counts for the epoch-level Dice.
        pred_classes = pred_logit.argmax(1)
        intersection += ((pred_classes == 1) * (y[:, 0] == 1)).sum().item()
        union += ((pred_classes == 1).sum() + y[:, 0].sum()).item()
        batch_size = y.size(0)

        iou = cal_batch_iou(pred_logit, y_one_hot)
        dice = cal_batch_dice(pred_logit, y_one_hot)

        # log; `writer` defaults to None, so guard the scalar logging.
        loss_value = loss.item()
        if writer is not None:
            writer.add_scalar("train_loss_logs", loss_value, iteration)
        with open(os.path.join(cfg.save, "loss_logs.csv"), "a") as f:
            f.write("%09d,%0.6f,\n" % (iteration + 1, loss_value))
        iteration += 1

        logs = (
            [loss_value, iou[1:].mean().item(), dice[1:].mean().item()]
            + [iou[i].item() for i in range(len(iou))]
            + [dice[i].item() for i in range(len(dice))]
            + [time.time() - end]
        )
        meters.update(logs, batch_size)
        end = time.time()

        # Print roughly once per 2 s worth of batches; max() avoids a
        # ZeroDivisionError if the measured batch time is 0.0.
        print_freq = 2 // max(meters.val[-1], 1e-8) + 1
        if batch_idx % print_freq == 0:
            res = "\t".join(
                [
                    "Epoch: [%d/%d]" % (epoch + 1, n_epochs),
                    "Iter: [%d/%d]" % (batch_idx + 1, len(loader)),
                    "Time %.3f (%.3f)" % (meters.val[-1], meters.avg[-1]),
                    "Loss %.4f (%.4f)" % (meters.val[0], meters.avg[0]),
                    "IOU %.4f (%.4f)" % (meters.val[1], meters.avg[1]),
                    "DICE %.4f (%.4f)" % (meters.val[2], meters.avg[2]),
                ]
            )
            print(res)
    # Dice is undefined when neither prediction nor label has any label-1 voxel.
    dice_global = 2.0 * intersection / union if union > 0 else float("nan")
    return meters.avg[:-1] + [dice_global]
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    '''
    One test epoch for the 3-class model.

    Evaluates ``model(x, False)`` on ``loader`` without gradient tracking.
    It also accumulates a pooled foreground Dice over labels 1 and 2:
    intersections and unions are summed over both labels before dividing.

    Args:
        model: network called as ``model(x, False)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``'Test'`` vs ``'Valid'`` progress prefix.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        pooled foreground Dice (``nan`` if ``union`` is 0).
    '''
    n_classes = 3
    meters = MultiAverageMeter()
    model.eval()
    intersection = 0
    union = 0
    end = time.time()
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x.float())
            y = to_var(y)
            # forward
            pred_logit = model(x, False)
            # calculate metrics
            y_one_hot = categorical_to_one_hot(y,
                                               dim=1,
                                               expand_dim=False,
                                               n_classes=n_classes)
            pred_classes = pred_logit.argmax(1)
            labels = y[:, 0]
            # Pooled (micro-averaged) Dice over foreground labels 1..n_classes-1.
            for c in range(1, n_classes):
                pred_c = pred_classes == c
                true_c = labels == c
                intersection += (pred_c & true_c).sum().item()
                union += (pred_c.sum() + true_c.sum()).item()

            loss = soft_dice_loss(pred_logit, y_one_hot)
            batch_size = y.size(0)

            iou = cal_batch_iou(pred_logit, y_one_hot)
            dice = cal_batch_dice(pred_logit, y_one_hot)

            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)

            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)
    # Dice is undefined when no foreground voxel is predicted or labelled.
    dice_global = 2. * intersection / union if union > 0 else float('nan')

    return meters.avg[:-1] + [dice_global]
def train_epoch(model,
                loader,
                optimizer,
                epoch,
                n_epochs,
                print_freq=1,
                writer=None):
    '''
    One training epoch for the 3-class model.

    Trains ``model(x, True)`` on every batch of ``loader``. Each iteration's
    loss is sent to ``writer`` (when one is given) and appended to
    ``<cfg.save>/loss_logs.csv``. The module-level ``iteration`` counter is
    incremented once per batch. The current learning rate of the first param
    group is also logged and averaged.

    Args:
        model: network called as ``model(x, True)``; put in train mode.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index (shown as ``epoch + 1``).
        n_epochs: total epoch count, used only in the progress line.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        writer: optional object with ``add_scalar(tag, value, step)``; when
            ``None`` the scalar logging is skipped.

    Returns:
        ``meters.avg`` without its last (batch-time) entry (so it ends with
        the averaged lr), followed by the pooled foreground Dice over labels
        1 and 2 (``nan`` if ``union`` is 0).
    '''
    global iteration
    n_classes = 3
    meters = MultiAverageMeter()
    model.train()
    intersection = 0
    union = 0
    end = time.time()
    for batch_idx, (x, y) in enumerate(loader):
        # Read the lr directly; state_dict() would rebuild a packed copy of
        # the optimizer state on every iteration just to get this value.
        lr = optimizer.param_groups[0]['lr']
        x = to_var(x.float())
        y = to_var(y)
        # forward and backward
        pred_logit = model(x, True)
        y_one_hot = categorical_to_one_hot(y,
                                           dim=1,
                                           expand_dim=False,
                                           n_classes=n_classes)
        loss = soft_dice_loss(pred_logit, y_one_hot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Pooled (micro-averaged) Dice over foreground labels 1..n_classes-1.
        pred_classes = pred_logit.argmax(1)
        labels = y[:, 0]
        for c in range(1, n_classes):
            pred_c = pred_classes == c
            true_c = labels == c
            intersection += (pred_c & true_c).sum().item()
            union += (pred_c.sum() + true_c.sum()).item()
        batch_size = y.size(0)

        iou = cal_batch_iou(pred_logit, y_one_hot)
        dice = cal_batch_dice(pred_logit, y_one_hot)

        # log; `writer` defaults to None, so guard the scalar logging.
        loss_value = loss.item()
        if writer is not None:
            writer.add_scalar('train_loss_logs', loss_value, iteration)
        with open(os.path.join(cfg.save, 'loss_logs.csv'), 'a') as f:
            f.write('%09d,%0.6f,\n' % (iteration + 1, loss_value))
        iteration += 1

        logs = ([loss_value, iou[1:].mean().item(), dice[1:].mean().item()]
                + [iou[i].item() for i in range(len(iou))]
                + [dice[i].item() for i in range(len(dice))]
                + [lr]
                + [time.time() - end])
        meters.update(logs, batch_size)
        end = time.time()

        # Print roughly once per 2 s worth of batches; max() avoids a
        # ZeroDivisionError if the measured batch time is 0.0.
        print_freq = 2 // max(meters.val[-1], 1e-8) + 1
        if batch_idx % print_freq == 0:
            res = '\t'.join([
                'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
                'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
            ])
            print(res)
        torch.cuda.empty_cache()
    # Dice is undefined when no foreground voxel is predicted or labelled.
    dice_global = 2. * intersection / union if union > 0 else float('nan')

    return meters.avg[:-1] + [dice_global]
Example #12
0
def test_epoch(model, loader, epoch, print_freq=1, is_test=True, writer=None):
    '''
    One test epoch with patch-wise inference.

    Each input volume is evaluated as eight cubic patches of side
    ``2 * width``, cut around ``centers``. The patch logits and the patch
    one-hot labels are concatenated along dim 4 before loss and metrics are
    computed. A pooled foreground Dice over labels 1 and 2 is accumulated
    per patch.

    Args:
        model: network called as ``model(patch)``; put in eval mode.
        loader: iterable of ``(x, y)`` batches.
        epoch: accepted for interface compatibility; not used.
        print_freq: kept for interface compatibility; overwritten every
            iteration by an adaptive value computed from the batch time.
        is_test: selects the ``'Test'`` vs ``'Valid'`` progress prefix.
        writer: accepted for interface compatibility; not used.

    Returns:
        ``meters.avg`` without its last (batch-time) entry, followed by the
        pooled foreground Dice (``nan`` if ``union`` is 0).
    '''
    n_classes = 3
    meters = MultiAverageMeter()
    model.eval()
    intersection = 0
    union = 0
    end = time.time()
    # 2x2x2 tiling: with width=24 the patches span indices 0..95 on each
    # spatial axis. Assumes volumes of at least 96^3 voxels - TODO confirm.
    centers = [[24, 24, 24], [24, 24, 72], [24, 72, 24], [72, 24, 24],
               [24, 72, 72], [72, 24, 72], [72, 72, 24], [72, 72, 72]]
    width = 24
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x = to_var(x.float())
            y = to_var(y)
            # Collect patches in lists and concatenate once. The earlier
            # version seeded with empty tensors of hard-coded shape
            # (·, 3, 48, 48, 0) and a long dtype for the labels.
            patch_logits = []
            patch_one_hots = []
            for center in centers:
                sx, sy, sz = (slice(c - width, c + width) for c in center)
                # forward
                pred_logit = model(x[:, :, sx, sy, sz])
                patch_logits.append(pred_logit)
                # calculate metrics
                tmp = y[:, :, sx, sy, sz]
                patch_one_hots.append(
                    categorical_to_one_hot(tmp,
                                           dim=1,
                                           expand_dim=False,
                                           n_classes=n_classes))
                pred_classes = pred_logit.argmax(1)
                labels = tmp[:, 0]
                # Pooled (micro-averaged) Dice over labels 1..n_classes-1.
                for c in range(1, n_classes):
                    pred_c = pred_classes == c
                    true_c = labels == c
                    intersection += (pred_c & true_c).sum().item()
                    union += (pred_c.sum() + true_c.sum()).item()
            pred_logits = torch.cat(patch_logits, 4)
            y_one_hots = torch.cat(patch_one_hots, 4)

            loss = soft_dice_loss(pred_logits, y_one_hots)
            batch_size = y.size(0)

            iou = cal_batch_iou(pred_logits, y_one_hots)
            dice = cal_batch_dice(pred_logits, y_one_hots)

            logs = ([loss.item(), iou[1:].mean().item(), dice[1:].mean().item()]
                    + [iou[i].item() for i in range(len(iou))]
                    + [dice[i].item() for i in range(len(dice))]
                    + [time.time() - end])
            meters.update(logs, batch_size)

            end = time.time()

            # Print roughly once per 2 s worth of batches; max() avoids a
            # ZeroDivisionError if the measured batch time is 0.0.
            print_freq = 2 // max(meters.val[-1], 1e-8) + 1
            if batch_idx % print_freq == 0:
                res = '\t'.join([
                    'Test' if is_test else 'Valid',
                    'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
                    'Time %.3f (%.3f)' % (meters.val[-1], meters.avg[-1]),
                    'Loss %.4f (%.4f)' % (meters.val[0], meters.avg[0]),
                    'IOU %.4f (%.4f)' % (meters.val[1], meters.avg[1]),
                    'DICE %.4f (%.4f)' % (meters.val[2], meters.avg[2]),
                ])
                print(res)
    # Dice is undefined when no foreground voxel is predicted or labelled.
    dice_global = 2. * intersection / union if union > 0 else float('nan')

    return meters.avg[:-1] + [dice_global]