def train(args, train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader`` and return averaged metrics.

    Args:
        args: Config object; only ``args.deepsupervision`` is read here.
        train_loader: Iterable of ``(input, target)`` batches. ``len()`` is
            used for the progress-bar total.
        model: Network to train; switched to training mode via ``model.train()``.
            When ``args.deepsupervision`` is set, it must return a sequence of
            outputs, and the last one is used for metrics.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer zeroed, back-propagated and stepped once per batch.
        epoch: Current epoch index. It is not used in this function and is kept
            for interface compatibility.

    Returns:
        OrderedDict with the ``AverageMeter.avg`` values of ``loss``, ``iou``,
        ``dice_1`` and ``dice_2``. Each value is weighted by batch size.
    """
    losses = AverageMeter()
    ious = AverageMeter()
    dices_1s = AverageMeter()
    dices_2s = AverageMeter()

    model.train()

    for inputs, target in tqdm(train_loader, total=len(train_loader)):
        inputs = inputs.cuda()
        target = target.cuda()
        batch_size = inputs.size(0)

        # compute output
        if args.deepsupervision:
            outputs = model(inputs)
            # Average the loss over every supervised output head.
            loss = 0
            for head_output in outputs:
                loss += criterion(head_output, target)
            loss /= len(outputs)
            # Metrics are computed on the final head only. Bind it explicitly
            # rather than relying on the loop variable leaking out of the loop.
            output = outputs[-1]
        else:
            output = model(inputs)
            loss = criterion(output, target)

        iou = iou_score(output, target)
        # Evaluate dice_coef once and reuse both components.
        dice = dice_coef(output, target)
        dice_1, dice_2 = dice[0], dice[1]

        losses.update(loss.item(), batch_size)
        ious.update(iou, batch_size)
        # NOTE(review): presumably dice_coef yields Python/NumPy scalars, which
        # torch.tensor wraps into 0-dim tensors. Confirm against dice_coef.
        dices_1s.update(torch.tensor(dice_1), batch_size)
        dices_2s.update(torch.tensor(dice_2), batch_size)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    log = OrderedDict([
        ('loss', losses.avg),
        ('iou', ious.avg),
        ('dice_1', dices_1s.avg),
        ('dice_2', dices_2s.avg),
    ])

    return log
def validate(args, val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients and return averaged metrics.

    Args:
        args: Config object; only ``args.deepsupervision`` is read here.
        val_loader: Iterable of ``(input, target)`` batches. ``len()`` is used
            for the progress-bar total.
        model: Network to evaluate; switched to eval mode via ``model.eval()``.
            When ``args.deepsupervision`` is set, it must return a sequence of
            outputs, and the last one is used for metrics.
        criterion: Loss callable invoked as ``criterion(output, target)``.

    Returns:
        OrderedDict with the ``AverageMeter.avg`` values of ``loss``, ``iou``,
        ``dice_1`` and ``dice_2``. Each value is weighted by batch size.
    """
    losses = AverageMeter()
    ious = AverageMeter()
    dices_1s = AverageMeter()
    dices_2s = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for inputs, target in tqdm(val_loader, total=len(val_loader)):
            inputs = inputs.cuda()
            target = target.cuda()
            batch_size = inputs.size(0)

            # compute output
            if args.deepsupervision:
                outputs = model(inputs)
                # Average the loss over every supervised output head.
                loss = 0
                for head_output in outputs:
                    loss += criterion(head_output, target)
                loss /= len(outputs)
                # Metrics are computed on the final head only. Bind it
                # explicitly rather than relying on the leaked loop variable.
                output = outputs[-1]
            else:
                output = model(inputs)
                loss = criterion(output, target)

            iou = iou_score(output, target)
            # Evaluate dice_coef once and reuse both components.
            dice = dice_coef(output, target)
            dice_1, dice_2 = dice[0], dice[1]

            losses.update(loss.item(), batch_size)
            ious.update(iou, batch_size)
            # NOTE(review): presumably dice_coef yields Python/NumPy scalars,
            # which torch.tensor wraps into 0-dim tensors. Confirm against
            # dice_coef.
            dices_1s.update(torch.tensor(dice_1), batch_size)
            dices_2s.update(torch.tensor(dice_2), batch_size)

    log = OrderedDict([
        ('loss', losses.avg),
        ('iou', ious.avg),
        ('dice_1', dices_1s.avg),
        ('dice_2', dices_2s.avg),
    ])

    return log