def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    """Run one deep-supervision training epoch and return the averaged metrics.

    The model output is indexed at positions 0..3 and ``loss_func`` is applied
    to each of them. The loss of ``output[3]`` is the main term; the losses of
    ``output[0..2]`` are auxiliary terms scaled by ``alpha``. Only the main
    loss and the Dice of ``output[3]`` are tracked for reporting.

    Reads the module-level names ``epoch`` and ``device``.

    Returns:
        OrderedDict with 'Train_Loss' and 'Train_dice_liver' (Dice index 1),
        plus 'Train_dice_tumor' (Dice index 2) when ``n_labels == 3``.
    """
    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("=======Epoch:{}=======lr:{}".format(epoch, current_lr))

    model.train()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)

    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for _step, (data, target) in progress:
        inputs = data.float()
        labels = common.to_one_hot_3d(target.long(), n_labels)
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Heads 0-2 are the auxiliary outputs, head 3 is the final prediction.
        head_losses = [loss_func(outputs[k], labels) for k in range(4)]
        aux0, aux1, aux2, main_loss = head_losses
        total_loss = main_loss + alpha * (aux0 + aux1 + aux2)

        total_loss.backward()
        optimizer.step()

        # Report the un-weighted main loss, not the combined objective.
        loss_meter.update(main_loss.item(), inputs.size(0))
        dice_meter.update(outputs[3], labels)

    train_log = OrderedDict()
    train_log['Train_Loss'] = loss_meter.avg
    train_log['Train_dice_liver'] = dice_meter.avg[1]
    if n_labels == 3:
        train_log['Train_dice_tumor'] = dice_meter.avg[2]
    return train_log
def predict_one_img(model, img_dataset, args):
    """Run inference over one image dataset and score the prediction with Dice.

    Every item of ``img_dataset`` is passed through ``model`` (batch size 1,
    no shuffling) and the detached CPU output is handed back to the dataset
    via ``update_result``. The dataset's ``recompone_result()`` output is then
    arg-maxed over dim 1, one-hot encoded, and compared against the one-hot
    encoding of ``img_dataset.label``.

    Args:
        model: network to evaluate; switched to eval mode here.
        img_dataset: dataset exposing ``label``, ``update_result`` and
            ``recompone_result`` (presumably a patch/stitch wrapper around a
            single volume -- confirm against its definition).
        args: namespace providing ``n_labels`` and ``postprocess``.

    Returns:
        Tuple ``(test_dice, pred)`` where ``test_dice`` is an OrderedDict with
        'Dice_liver' (Dice index 1) and, when ``args.n_labels == 3``,
        'Dice_tumor' (Dice index 2); ``pred`` is a SimpleITK image built from
        the uint8 label map.
    """
    dataloader = DataLoader(dataset=img_dataset, batch_size=1, num_workers=0, shuffle=False)
    model.eval()
    # Keep the meter under its own name: the result dict below must not shadow
    # it, otherwise the tumor Dice lookup would hit the dict instead.
    dice_meter = DiceAverage(args.n_labels)
    target = to_one_hot_3d(img_dataset.label, args.n_labels)

    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            output = model(data)
            # NOTE: restoring the original spatial resolution with a trilinear
            # nn.functional.interpolate (scale factors derived from
            # args.slice_down_scale / args.xy_down_scale) is currently disabled.
            img_dataset.update_result(output.detach().cpu())

    pred = img_dataset.recompone_result()
    pred = torch.argmax(pred, dim=1)

    pred_img = common.to_one_hot_3d(pred, args.n_labels)
    dice_meter.update(pred_img, target)

    test_dice = OrderedDict({'Dice_liver': dice_meter.avg[1]})
    if args.n_labels == 3:
        test_dice.update({'Dice_tumor': dice_meter.avg[2]})

    pred = np.asarray(pred.numpy(), dtype='uint8')
    if args.postprocess:
        pass  # TODO: post-processing is not implemented yet
    # Drop the leading axis (must have size 1) before converting to an image.
    pred = sitk.GetImageFromArray(np.squeeze(pred, axis=0))

    return test_dice, pred
def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return per-batch averages.

    For every batch the DiceMeanLoss and the Dice score at indices 0, 1 and 2
    are computed, converted to Python floats and summed; the sums are then
    divided by the number of batches.

    Reads the module-level name ``device``.

    Returns:
        OrderedDict with keys 'Val Loss', 'Val dice0', 'Val dice1', 'Val dice2'.
    """
    model.eval()

    # Running sums, in order: loss, dice0, dice1, dice2.
    totals = [0, 0, 0, 0]

    with torch.no_grad():
        for _step, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            labels = common.to_one_hot_3d(target.long())
            inputs = data.float()
            labels = labels.float()
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)

            batch_values = [metrics.DiceMeanLoss()(preds, labels)]
            for class_idx in range(3):
                batch_values.append(metrics.dice(preds, labels, class_idx))

            for slot, value in enumerate(batch_values):
                totals[slot] += float(value)

    for slot in range(len(totals)):
        totals[slot] /= len(val_loader)

    summary = OrderedDict()
    summary['Val Loss'] = totals[0]
    summary['Val dice0'] = totals[1]
    summary['Val dice1'] = totals[2]
    summary['Val dice2'] = totals[3]
    return summary
def val(model, val_loader, criterion, n_labels):
    """Evaluate ``model`` on ``val_loader`` using running-average meters.

    Targets are cast to long and one-hot encoded with ``n_labels`` classes
    before being compared with the model output by ``criterion`` and the
    Dice meter.

    Reads the module-level name ``device``.

    Returns:
        OrderedDict with 'Val Loss', 'Val dice0' and 'Val dice1'; 'Val dice2'
        is added whenever ``n_labels`` is not 2.
    """
    model.eval()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)

    with torch.no_grad():
        batches = tqdm(enumerate(val_loader), total=len(val_loader))
        for _step, (data, target) in batches:
            inputs = data.float()
            labels = common.to_one_hot_3d(target.long(), n_labels)
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)
            batch_loss = criterion(preds, labels)

            loss_meter.update(batch_loss.item(), inputs.size(0))
            dice_meter.update(preds, labels)

    results = OrderedDict()
    results['Val Loss'] = loss_meter.avg
    results['Val dice0'] = dice_meter.avg[0]
    results['Val dice1'] = dice_meter.avg[1]
    if n_labels != 2:
        # Any label count other than 2 also reports the third Dice entry.
        results['Val dice2'] = dice_meter.avg[2]
    return results
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch with DiceMeanLoss, then print and log averages.

    For each batch the target is cast to long, one-hot encoded and cast to
    float; the model is optimised on ``metrics.DiceMeanLoss``. The loss and the
    Dice scores at indices 0, 1 and 2 are averaged over the number of batches,
    printed, and written to ``logger.scalar_summary`` under 'train_loss',
    'train_dice0', 'train_dice1' and 'train_dice2'.

    Reads the module-level name ``device``.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable of ``(data, target)`` batches supporting ``len``.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch index used for the banner and as the logging step.
        logger: object exposing ``scalar_summary(tag, value, step)``.

    Returns:
        None.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()

    train_loss = 0
    train_dice0 = 0
    train_dice1 = 0
    train_dice2 = 0
    for idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        target = common.to_one_hot_3d(target.long())
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)
        optimizer.zero_grad()

        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python floats. Summing the tensors themselves would
        # keep every batch's autograd history referenced for the whole epoch.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))

    train_loss /= len(train_loader)
    train_dice0 /= len(train_loader)
    train_dice1 /= len(train_loader)
    train_dice2 /= len(train_loader)

    print(
        'Train Epoch: {} \tLoss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}'
        .format(epoch, train_loss, train_dice0, train_dice1, train_dice2))

    logger.scalar_summary('train_loss', train_loss, epoch)
    logger.scalar_summary('train_dice0', train_dice0, epoch)
    logger.scalar_summary('train_dice1', train_dice1, epoch)
    logger.scalar_summary('train_dice2', train_dice2, epoch)
def val(model, val_loader, loss_func, n_labels):
    """Validate ``model`` and return the averaged loss and Dice scores.

    Each batch is cast (data to float, target to long), the target is one-hot
    encoded with ``n_labels`` classes, and both are moved to ``device`` before
    the forward pass. Loss and Dice are tracked with running-average meters.

    Reads the module-level name ``device``.

    Returns:
        OrderedDict with 'Val_Loss' and 'Val_dice_liver' (Dice index 1), plus
        'Val_dice_tumor' (Dice index 2) when ``n_labels == 3``.
    """
    model.eval()
    loss_tracker = metrics.LossAverage()
    dice_tracker = metrics.DiceAverage(n_labels)

    with torch.no_grad():
        for _batch_no, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            volumes = data.float()
            masks = common.to_one_hot_3d(target.long(), n_labels)
            volumes = volumes.to(device)
            masks = masks.to(device)

            prediction = model(volumes)
            batch_loss = loss_func(prediction, masks)

            loss_tracker.update(batch_loss.item(), volumes.size(0))
            dice_tracker.update(prediction, masks)

    report = OrderedDict([
        ('Val_Loss', loss_tracker.avg),
        ('Val_dice_liver', dice_tracker.avg[1]),
    ])
    if n_labels == 3:
        report['Val_dice_tumor'] = dice_tracker.avg[2]
    return report
def train(model, train_loader, optimizer):
    """Train ``model`` for one epoch with WeightDiceLoss and return averages.

    For every batch the loss and the Dice scores at indices 0, 1 and 2 are
    converted to Python floats and summed; the sums are divided by the number
    of batches at the end.

    Reads the module-level names ``epoch`` and ``device``.

    Returns:
        OrderedDict with keys 'Train Loss', 'Train dice0', 'Train dice1',
        'Train dice2'.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()

    # Running sums, in order: loss, dice0, dice1, dice2.
    sums = [0, 0, 0, 0]

    for _step, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        labels = common.to_one_hot_3d(target.long())
        inputs = data.float()
        labels = labels.float()
        inputs = inputs.to(device)
        labels = labels.to(device)

        preds = model(inputs)
        optimizer.zero_grad()

        # WeightDiceLoss is the active objective; CrossEntropyLoss,
        # SoftDiceLoss, MSELoss, DiceMeanLoss and CrossEntropy were
        # alternatives tried at this spot.
        batch_loss = metrics.WeightDiceLoss()(preds, labels)
        batch_loss.backward()
        optimizer.step()

        sums[0] += float(batch_loss)
        for class_idx in range(3):
            sums[class_idx + 1] += float(metrics.dice(preds, labels, class_idx))

    for slot in range(len(sums)):
        sums[slot] /= len(train_loader)

    summary = OrderedDict()
    summary['Train Loss'] = sums[0]
    summary['Train dice0'] = sums[1]
    summary['Train dice1'] = sums[2]
    summary['Train dice2'] = sums[3]
    return summary
def train(model, train_loader, optimizer, criterion, n_labels):
    """Run one training epoch with ``criterion`` and return averaged metrics.

    Prints a banner with the epoch and the learning rate of the optimizer's
    first parameter group, then optimises the model batch by batch while
    tracking the loss and Dice with running-average meters.

    Reads the module-level names ``epoch`` and ``device``.

    Returns:
        OrderedDict with 'Train Loss', 'Train dice0' and 'Train dice1';
        'Train dice2' is added whenever ``n_labels`` is not 2.
    """
    first_group_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("=======Epoch:{}=======lr:{}".format(epoch, first_group_lr))

    model.train()
    loss_meter = metrics.LossAverage()
    dice_meter = metrics.DiceAverage(n_labels)

    batches = tqdm(enumerate(train_loader), total=len(train_loader))
    for _step, (data, target) in batches:
        inputs = data.float()
        labels = common.to_one_hot_3d(target.long(), n_labels)
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        preds = model(inputs)
        batch_loss = criterion(preds, labels)
        batch_loss.backward()
        optimizer.step()

        loss_meter.update(batch_loss.item(), inputs.size(0))
        dice_meter.update(preds, labels)

    results = OrderedDict()
    results['Train Loss'] = loss_meter.avg
    results['Train dice0'] = dice_meter.avg[0]
    results['Train dice1'] = dice_meter.avg[1]
    if n_labels != 2:
        # Any label count other than 2 also reports the third Dice entry.
        results['Train dice2'] = dice_meter.avg[2]
    return results
def val(model, val_loader, epoch, logger):
    """Validate ``model``, log the averaged metrics and print a summary line.

    For every batch the DiceMeanLoss and the Dice scores at indices 0, 1 and 2
    are converted to Python floats and summed; the sums are divided by the
    number of batches, written to ``logger.scalar_summary`` with ``epoch`` as
    the step, and printed.

    Reads the module-level name ``device``. Returns None.
    """
    model.eval()

    # Running sums, in order: loss, dice0, dice1, dice2.
    sums = [0, 0, 0, 0]

    with torch.no_grad():
        for _step, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            labels = common.to_one_hot_3d(target.long())
            inputs = data.float()
            labels = labels.float()
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)

            batch_values = [metrics.DiceMeanLoss()(preds, labels)]
            for class_idx in range(3):
                batch_values.append(metrics.dice(preds, labels, class_idx))

            for slot, value in enumerate(batch_values):
                sums[slot] += float(value)

    for slot in range(len(sums)):
        sums[slot] /= len(val_loader)
    mean_loss, mean_dice0, mean_dice1, mean_dice2 = sums

    tagged_values = (
        ('val_loss', mean_loss),
        ('val_dice0', mean_dice0),
        ('val_dice1', mean_dice1),
        ('val_dice2', mean_dice2),
    )
    for tag, value in tagged_values:
        logger.scalar_summary(tag, value, epoch)

    print(
        'Val performance: Average loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(mean_loss, mean_dice0, mean_dice1, mean_dice2))