import torch

# dice_coefficient, logger, save_checkpoint, downsample_img and
# downsample_mask are assumed to be defined or imported elsewhere in this repo.


def test(args, model, device, test_loader, meters, epoch, criterion):
    testloss = meters['loss']
    testdice = meters['dice']
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, mask) in enumerate(test_loader):
            # Add a channel dimension and cast to float for the network.
            data = data.unsqueeze(1).float()
            mask = mask.unsqueeze(1).float()
            data, mask = data.to(device), mask.to(device)
            output = model(data)
            loss = criterion(output, mask)
            # Alternative:
            # loss = F.binary_cross_entropy_with_logits(output, mask, reduction='sum')
            dice = dice_coefficient(output, mask)
            testdice.update(dice)
            testloss.update(loss.item())
        # Log the per-epoch running averages, not the last batch's loss.
        info = {'test_loss': testloss.avg, 'test_dice': testdice.avg}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch)
    print('\nTest set: Average loss: {:.4f}, Average Dice Coefficient: {:.6f}\n'
          .format(testloss.avg, testdice.avg))
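# Both loops in this file call dice_coefficient, which lives elsewhere in the
# repo. For reference, a minimal sketch of a soft Dice coefficient for binary
# segmentation, assuming `output` and `mask` are same-shaped tensors with
# values in [0, 1]; the name, flattening strategy and smoothing constant here
# are illustrative, not the repo's actual implementation:
def dice_coefficient_sketch(output, mask, smooth=1e-6):
    # Flatten to 1-D so overlap is computed over the whole batch at once.
    output = output.contiguous().view(-1)
    mask = mask.contiguous().view(-1)
    intersection = (output * mask).sum()
    return (2. * intersection + smooth) / (output.sum() + mask.sum() + smooth)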
def train(args, model, device, train_loader, optimizer, epoch, meters,
          criterion, savepath=None, savefile=None):
    trainloss = meters['loss']
    traindice = meters['dice']
    model.train()
    for batch_idx, (data, mask) in enumerate(train_loader):
        data = data.unsqueeze(1).float()
        mask = mask.unsqueeze(1).float()
        data, mask = data.to(device), mask.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, mask)
        # Alternative: loss = F.binary_cross_entropy_with_logits(output, mask)
        dice = dice_coefficient(output, mask)
        loss.backward()
        optimizer.step()
        trainloss.update(loss.item())
        traindice.update(dice)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Dice: {:.6f}'
                  .format(epoch, batch_idx * len(data),
                          len(train_loader.dataset),
                          100. * batch_idx / len(train_loader),
                          loss.item(), traindice.avg))
    # Per-epoch scalar, histogram and image summaries.
    info = {'train_loss': loss.item(), 'train_dice': traindice.avg}
    for tag, value in info.items():
        logger.scalar_summary(tag, value, epoch)
    for tag, value in model.named_parameters():
        tag = tag.replace('.', '/')
        logger.histo_summary(tag, value.data.cpu().numpy(), epoch)
        logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), epoch)
    # Log the first two predicted segmentations from the last batch.
    imgs = output.view(-1, 512, 512)[:2].detach().cpu().numpy()
    info = {'segmentations': imgs}
    for tag, images in info.items():
        logger.image_summary(tag, images, epoch)
    if args.checkpoint:
        save_checkpoint(model, optimizer, epoch, loss, savepath, savefile)
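# save_checkpoint is also defined elsewhere in the repo. A minimal sketch of
# what such a helper typically does, assuming the argument order of the call
# above; the state-dict keys and filename handling are illustrative:
import os

def save_checkpoint_sketch(model, optimizer, epoch, loss, savepath, savefile):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item() if torch.is_tensor(loss) else loss,
    }
    torch.save(state, os.path.join(savepath, savefile))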
def train(args, model, start_gpu, end_gpu, train_loader, optimizer, epoch,
          meters, criterion):
    trainloss = meters['loss']
    traindice = meters['dice']
    model.train()
    for batch_idx, (data, mask) in enumerate(train_loader):
        data = data.unsqueeze(1).float()
        mask = mask.unsqueeze(1).float()
        data = downsample_img(data)
        mask = downsample_mask(mask)
        # Model-parallel: the input enters on the first GPU; the model's
        # forward pass delivers its output on the last GPU, where the mask lives.
        data, mask = data.to(start_gpu), mask.to(end_gpu)
        optimizer.zero_grad()
        output = model(data)
        # Alternative:
        # loss = F.binary_cross_entropy_with_logits(output, mask, reduction='mean')
        output = torch.sigmoid(output)
        loss = criterion(output, mask)
        # Threshold the probabilities before measuring Dice overlap.
        with torch.no_grad():
            output_binary = output > 0.5
            dice = dice_coefficient(output_binary.float(), mask)
        loss.backward()
        optimizer.step()
        trainloss.update(loss.item())
        traindice.update(dice)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Dice: {:.6f}'
                  .format(epoch, batch_idx * len(data),
                          len(train_loader.dataset),
                          100. * batch_idx / len(train_loader),
                          loss.item(), traindice.avg))
    # Per-epoch scalar, histogram and image summaries.
    info = {'train_loss': loss.item(), 'train_dice': traindice.avg}
    for tag, value in info.items():
        logger.scalar_summary(tag, value, epoch)
    for tag, value in model.named_parameters():
        tag = tag.replace('.', '/')
        logger.histo_summary(tag, value.data.cpu().numpy(), epoch)
        logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), epoch)
    # Inputs were downsampled, so the logged segmentations are 256x256.
    imgs = output.view(-1, 256, 256)[:2].detach().cpu().numpy()
    info = {'segmentations': imgs}
    for tag, images in info.items():
        logger.image_summary(tag, images, epoch)
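# downsample_img and downsample_mask are defined elsewhere in the repo. Given
# that the single-GPU variant logs 512x512 outputs and this variant logs
# 256x256, a plausible sketch halves the spatial resolution: bilinear for
# images, nearest-neighbor for masks so labels stay binary. The scale factor
# and interpolation modes here are assumptions, not the repo's code:
import torch.nn.functional as F

def downsample_img_sketch(img):
    return F.interpolate(img, scale_factor=0.5, mode='bilinear',
                         align_corners=False)

def downsample_mask_sketch(mask):
    return F.interpolate(mask, scale_factor=0.5, mode='nearest')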