def test(model, test_loader):
    """Evaluate ``model`` over ``test_loader`` and print batch-averaged metrics.

    Every ``(data, target)`` batch is cast to float, moved to the module-level
    ``device`` and passed through the model with autograd disabled. The
    ``metrics.DiceMeanLoss`` value and ``metrics.dice`` for indices 0, 1 and 2
    (presumably class/channel indices -- confirm in ``metrics``) are summed
    and divided by ``len(test_loader)``, i.e. averaged over batches rather
    than over samples.

    Nothing is returned; the averages are only printed.
    """
    print("Evaluation of Testset Starting...")
    model.eval()
    # Running sums, in order: loss, dice0, dice1, dice2.
    sums = [0, 0, 0, 0]
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            prediction = model(inputs)
            batch_scores = (
                metrics.DiceMeanLoss()(prediction, labels),
                metrics.dice(prediction, labels, 0),
                metrics.dice(prediction, labels, 1),
                metrics.dice(prediction, labels, 2),
            )
            for slot, score in enumerate(batch_scores):
                sums[slot] += float(score)
    batch_count = len(test_loader)
    mean_loss, mean_dice0, mean_dice1, mean_dice2 = [s / batch_count for s in sums]
    print('\nTest set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        mean_loss, mean_dice0, mean_dice1, mean_dice2))
def val(model, val_loader):
    """Validate ``model`` on ``val_loader`` and return batch-averaged metrics.

    Each batch's ``target`` is cast to ``long`` and converted with
    ``common.to_one_hot_3d`` before both tensors are cast to float and moved
    to the module-level ``device``. Inference runs with autograd disabled.

    Returns:
        OrderedDict with keys ``'Val Loss'``, ``'Val dice0'``, ``'Val dice1'``
        and ``'Val dice2'``: the ``metrics.DiceMeanLoss`` value and the
        ``metrics.dice`` scores for indices 0-2, each summed over batches and
        divided by ``len(val_loader)``.
    """
    model.eval()
    loss_total = 0
    dice_totals = [0, 0, 0]
    with torch.no_grad():
        for data, target in tqdm(val_loader, total=len(val_loader)):
            one_hot = common.to_one_hot_3d(target.long())
            data = data.float().to(device)
            one_hot = one_hot.float().to(device)
            output = model(data)
            loss = metrics.DiceMeanLoss()(output, one_hot)
            dices = [metrics.dice(output, one_hot, k) for k in range(3)]
            loss_total += float(loss)
            for k, dice in enumerate(dices):
                dice_totals[k] += float(dice)
    n_batches = len(val_loader)
    return OrderedDict([
        ('Val Loss', loss_total / n_batches),
        ('Val dice0', dice_totals[0] / n_batches),
        ('Val dice1', dice_totals[1] / n_batches),
        ('Val dice2', dice_totals[2] / n_batches),
    ])
def test(model, dataset, save_path, filename):
    """Predict a whole volume from ``dataset``, score it and save the label map.

    The dataset is read in batches of 4 (no shuffling, no worker processes).
    Each batch gets a new axis at position 1 via ``unsqueeze(1)`` before going
    through the model with autograd disabled. Every output is detached, moved
    to CPU and handed to ``Recompone_tool.add_result``.
    ``recompone_overlap()`` then yields the prediction that is scored. It
    presumably stitches the per-batch outputs back into one volume using
    ``dataset.ori_shape`` / ``new_shape`` / ``cut``; confirm in
    ``Recompone_tool``.

    The target is ``dataset.label_np`` with a leading axis added, cast to
    long and passed through ``to_one_hot_3d``. The argmax over dim 1 of the
    prediction is written as a uint8 image to
    ``os.path.join(save_path, filename)`` with SimpleITK.

    Returns:
        ``(loss, dice0, dice1, dice2)`` exactly as produced by
        ``metrics.DiceMeanLoss`` and ``metrics.dice``.
    """
    model.eval()
    loader = DataLoader(dataset=dataset, batch_size=4, num_workers=0, shuffle=False)
    stitcher = Recompone_tool(save_path, filename, dataset.ori_shape,
                              dataset.new_shape, dataset.cut)
    # Add a leading batch axis to the full label volume, then one-hot it.
    label_batch = torch.from_numpy(np.expand_dims(dataset.label_np, axis=0)).long()
    one_hot_target = to_one_hot_3d(label_batch)

    with torch.no_grad():
        for patch_batch in tqdm(loader, total=len(loader)):
            logits = model(patch_batch.unsqueeze(1).float().to(device))
            stitcher.add_result(logits.detach().cpu())

    full_pred = torch.unsqueeze(stitcher.recompone_overlap(), dim=0)
    scores = (
        metrics.DiceMeanLoss()(full_pred, one_hot_target),
        metrics.dice(full_pred, one_hot_target, 0),
        metrics.dice(full_pred, one_hot_target, 1),
        metrics.dice(full_pred, one_hot_target, 2),
    )

    # The image is written here with SimpleITK; the recomposition tool's own
    # save path is not used.
    label_map = torch.argmax(full_pred, dim=1)
    voxels = np.squeeze(np.array(label_map.numpy(), dtype='uint8'), axis=0)
    sitk.WriteImage(sitk.GetImageFromArray(voxels), os.path.join(save_path, filename))

    print('\nAverage loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'.format(*scores))
    return scores
def val(model, val_loader, epoch, logger):
    """Validate ``model``, log batch-averaged metrics for ``epoch`` and print them.

    Each ``(data, target)`` batch is cast to float, moved to the module-level
    ``device`` and evaluated with autograd disabled. The
    ``metrics.DiceMeanLoss`` value and ``metrics.dice`` for indices 0-2 are
    summed and divided by ``len(val_loader)``. The averages are sent to
    ``logger.scalar_summary`` with step ``epoch`` and then printed. Nothing is
    returned.
    """
    model.eval()
    loss_sum = dice0_sum = dice1_sum = dice2_sum = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.float().to(device)
            target = target.float().to(device)
            output = model(data)
            batch_loss = metrics.DiceMeanLoss()(output, target)
            batch_d0 = metrics.dice(output, target, 0)
            batch_d1 = metrics.dice(output, target, 1)
            batch_d2 = metrics.dice(output, target, 2)
            loss_sum += float(batch_loss)
            dice0_sum += float(batch_d0)
            dice1_sum += float(batch_d1)
            dice2_sum += float(batch_d2)
    n_batches = len(val_loader)
    mean_loss = loss_sum / n_batches
    mean_d0 = dice0_sum / n_batches
    mean_d1 = dice1_sum / n_batches
    mean_d2 = dice2_sum / n_batches
    for tag, value in (('val_loss', mean_loss),
                       ('val_dice0', mean_d0),
                       ('val_dice1', mean_d1),
                       ('val_dice2', mean_d2)):
        logger.scalar_summary(tag, value, epoch)
    print('\nVal set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        mean_loss, mean_d0, mean_d1, mean_d2))
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch; print and log batch-averaged metrics.

    Each ``(data, target)`` batch has its leading axis squeezed away
    (``torch.squeeze(..., dim=0)``), is cast to float and moved to the
    module-level ``device``. The model is optimised on
    ``metrics.DiceMeanLoss``. The loss and ``metrics.dice`` for indices 0, 1
    and 2 are averaged over batches, printed, and written to
    ``logger.scalar_summary`` under step ``epoch``.

    Args:
        model: network to train; put into ``train()`` mode.
        train_loader: iterable of ``(data, target)`` pairs supporting
            ``len()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: value shown in the banner/print and used as the logging step.
        logger: object exposing ``scalar_summary(tag, value, step)``.

    Raises:
        ZeroDivisionError: if ``train_loader`` is empty.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()
    loss_sum = 0.0
    dice_sums = [0.0, 0.0, 0.0]
    for data, target in tqdm(train_loader, total=len(train_loader)):
        data = torch.squeeze(data, dim=0).float().to(device)
        target = torch.squeeze(target, dim=0).float().to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()
        # Accumulate plain Python floats. Summing the loss/dice tensors
        # themselves would chain every batch's autograd history onto the
        # running totals and keep it (plus the tensors the dice computation
        # saved, e.g. `output`) alive until the end of the epoch.
        loss_sum += float(loss)
        # Dice is reporting-only, so record no autograd history for it.
        with torch.no_grad():
            for cls in range(3):
                dice_sums[cls] += float(metrics.dice(output, target, cls))
    num_batches = len(train_loader)
    train_loss = loss_sum / num_batches
    train_dice0, train_dice1, train_dice2 = (s / num_batches for s in dice_sums)
    print(
        'Train Epoch: {} \tLoss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}'
        .format(epoch, train_loss, train_dice0, train_dice1, train_dice2))
    logger.scalar_summary('train_loss', train_loss, epoch)
    logger.scalar_summary('train_dice0', train_dice0, epoch)
    logger.scalar_summary('train_dice1', train_dice1, epoch)
    logger.scalar_summary('train_dice2', train_dice2, epoch)
def train(model, train_loader, optimizer=None, epoch=None):
    """Train ``model`` for one epoch and return batch-averaged metrics.

    Each ``(data, target)`` batch has its leading axis squeezed away
    (``torch.squeeze(..., dim=0)``), is cast to float and moved to the
    module-level ``device``. The model is optimised on
    ``metrics.DiceMeanLoss``.

    Args:
        model: network to train; put into ``train()`` mode.
        train_loader: iterable of ``(data, target)`` pairs supporting
            ``len()``. The length is used for the progress bar and for
            averaging.
        optimizer: optimizer to step. Optional only for backward
            compatibility: when omitted, a module-level ``optimizer`` is
            used, which is what this function originally read implicitly.
        epoch: value shown in the banner line. When omitted, a module-level
            ``epoch`` is used if one exists; otherwise ``None`` is shown.

    Returns:
        OrderedDict with keys ``'Train Loss'``, ``'Train dice0'``,
        ``'Train dice1'`` and ``'Train dice2'``, each a float averaged over
        batches.

    Raises:
        NameError: if no optimizer is passed and none exists at module level.
        ZeroDivisionError: if ``train_loader`` is empty.
    """
    module_scope = globals()
    if optimizer is None:
        if 'optimizer' not in module_scope:
            raise NameError(
                "train() needs an optimizer: pass optimizer=... or define a "
                "module-level 'optimizer'")
        optimizer = module_scope['optimizer']
    if epoch is None:
        epoch = module_scope.get('epoch')

    print("=======Epoch:{}=======".format(epoch))
    model.train()
    loss_sum = 0.0
    dice_sums = [0.0, 0.0, 0.0]
    for data, target in tqdm(train_loader, total=len(train_loader)):
        data = torch.squeeze(data, dim=0).float().to(device)
        target = torch.squeeze(target, dim=0).float().to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()
        loss_sum += float(loss)
        # Dice is reporting-only; skip building an autograd graph for it.
        with torch.no_grad():
            for cls in range(3):
                dice_sums[cls] += float(metrics.dice(output, target, cls))
    num_batches = len(train_loader)
    return OrderedDict({
        'Train Loss': loss_sum / num_batches,
        'Train dice0': dice_sums[0] / num_batches,
        'Train dice1': dice_sums[1] / num_batches,
        'Train dice2': dice_sums[2] / num_batches,
    })
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch, printing per-batch progress.

    Each ``(data, target)`` batch has its leading axis squeezed away
    (``torch.squeeze(..., dim=0)``), is cast to float and moved to the
    module-level ``device``. The model is optimised on
    ``metrics.DiceMeanLoss``.

    A progress line with the batch loss, ``metrics.dice`` for indices 0-2 and
    ``metrics.T``/``P``/``TP`` is printed for every batch. After the epoch,
    the loss and dice values averaged over batches are written to
    ``logger.scalar_summary`` with step ``epoch``. Previously only the final
    batch's values were logged, which disagreed with the other ``train``
    variants that log epoch averages under the same tags. An empty loader
    logs 0.0 for every tag.

    Args:
        model: network to train; put into ``train()`` mode.
        train_loader: iterable of ``(data, target)`` pairs supporting
            ``len()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: epoch number for the progress lines and the logging step.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    model.train()
    num_batches = len(train_loader)
    loss_sum = dice0_sum = dice1_sum = dice2_sum = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        # Reporting-only metrics: don't record autograd history for them.
        with torch.no_grad():
            dice0 = float(metrics.dice(output, target, 0))
            dice1 = float(metrics.dice(output, target, 1))
            dice2 = float(metrics.dice(output, target, 2))
            t_val = metrics.T(output, target)
            p_val = metrics.P(output, target)
            tp_val = metrics.TP(output, target)
        loss_sum += batch_loss
        dice0_sum += dice0
        dice1_sum += dice1
        dice2_sum += dice2
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tdice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tT: {:.6f}\tP: {:.6f}\tTP: {:.6f}'
            .format(epoch, batch_idx, num_batches, 100. * batch_idx / num_batches,
                    batch_loss, dice0, dice1, dice2, t_val, p_val, tp_val))
    # max(..., 1) keeps an empty loader logging 0.0, as the original did.
    denom = max(num_batches, 1)
    logger.scalar_summary('train_loss', loss_sum / denom, epoch)
    logger.scalar_summary('train_dice0', dice0_sum / denom, epoch)
    logger.scalar_summary('train_dice1', dice1_sum / denom, epoch)
    logger.scalar_summary('train_dice2', dice2_sum / denom, epoch)
def val(model, val_loader, epoch, logger):
    """Validate ``model``, log batch-averaged metrics for ``epoch`` and print them.

    Each ``(data, target)`` batch has its leading axis squeezed away
    (``torch.squeeze(..., dim=0)``), is cast to float, moved to the
    module-level ``device`` and evaluated with autograd disabled. The
    ``metrics.DiceMeanLoss`` value and ``metrics.dice`` for indices 0-2 are
    summed and divided by ``len(val_loader)``. The averages are sent to
    ``logger.scalar_summary`` with step ``epoch`` and then printed. Nothing is
    returned.
    """
    model.eval()
    # Running sums, in order: loss, dice0, dice1, dice2.
    totals = [0, 0, 0, 0]
    with torch.no_grad():
        for data, target in tqdm(val_loader, total=len(val_loader)):
            data = torch.squeeze(data, dim=0).float().to(device)
            target = torch.squeeze(target, dim=0).float().to(device)
            output = model(data)
            per_batch = (
                metrics.DiceMeanLoss()(output, target),
                metrics.dice(output, target, 0),
                metrics.dice(output, target, 1),
                metrics.dice(output, target, 2),
            )
            for slot, value in enumerate(per_batch):
                totals[slot] += float(value)
    n_batches = len(val_loader)
    avg_loss, avg_d0, avg_d1, avg_d2 = [t / n_batches for t in totals]
    for tag, value in (('val_loss', avg_loss),
                       ('val_dice0', avg_d0),
                       ('val_dice1', avg_d1),
                       ('val_dice2', avg_d2)):
        logger.scalar_summary(tag, value, epoch)
    print(
        'Val performance: Average loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(avg_loss, avg_d0, avg_d1, avg_d2))