Ejemplo n.º 1
0
 def test_dice_score(self):
     """The Dice score computed for TEST_CASES[0] must be exactly 1.0."""
     case = TEST_CASES[0]
     seg = case['seg_img']
     ref = case['ref_img']
     measures = PairwiseMeasures(seg_img=seg, ref_img=ref)
     self.assertEqual(measures.dice_score(), 1.0)
Ejemplo n.º 2
0
def _run_phase(model, dataloader, criterion, optimizer, device, phase):
    """Run one pass over ``dataloader`` and return ``(mean_loss, mean_dice)``.

    In the ``'training'`` phase gradients are tracked and the optimizer is
    stepped after every batch; in any other phase the pass is forward-only.
    Both returned means are weighted by the number of samples per batch.

    Raises:
        ValueError: if ``dataloader`` yields no samples (the means would be
            undefined).
    """
    training = phase == 'training'
    model.train(training)  # equivalent to model.eval() when training is False

    loss_sum = 0.0
    dice_sum = 0.0
    n_samples = 0

    for inputs, labels in dataloader:
        # Fold the second axis into the batch axis. The view to 5 dims only
        # succeeds if the trailing (7th) input dimension is 1 and labels hold
        # the same number of elements as inputs.
        # NOTE(review): presumably axis 1 is a window/patch axis -- confirm.
        nbatches, wsize, nchannels, x, y, z, _ = inputs.size()
        inputs = inputs.view(nbatches * wsize, nchannels, x, y, z).to(device)
        labels = labels.view(nbatches * wsize, nchannels, x, y, z).to(device)

        optimizer.zero_grad()

        # Track history only while training.
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            # Binarise at 0.5. NOTE(review): assumes the model outputs
            # probabilities (e.g. after a sigmoid) -- verify.
            pred = outputs > 0.5
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

        batch_n = inputs.size(0)
        n_samples += batch_n
        loss_sum += loss.item() * batch_n
        measures = PairwiseMeasures(pred.cpu().numpy(), labels.cpu().numpy())
        dice_sum += measures.dice_score() * batch_n

    if n_samples == 0:
        raise ValueError(
            "'{}' dataloader produced no samples".format(phase))

    return loss_sum / n_samples, dice_sum / n_samples


def train(dsets, model, criterion, optimizer, num_epochs, device, cp_path,
          batch_size, num_workers=4):
    """Train ``model`` and checkpoint it whenever validation loss improves.

    Every epoch runs a ``'training'`` phase (forward, backward, optimizer
    step) followed by a ``'validation'`` phase (forward only). After each
    phase it prints the sample-weighted mean loss and Dice score.

    Args:
        dsets: Mapping with ``'training'`` and ``'validation'`` datasets.
        model: Network to train; it is moved to ``device`` first.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer for ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device that the model and every batch are moved to.
        cp_path: Checkpoint path. It is passed through ``str.format`` with
            the 1-based epoch number, so it may contain a ``{}`` placeholder
            to keep one file per improving epoch.
        batch_size: Batch size for both data loaders.
        num_workers: Worker processes per data loader (default 4).
    """
    since = time.time()

    dataloaders = {
        phase: DataLoader(dsets[phase],
                          batch_size=batch_size,
                          # Shuffle only training data. The Dice score is
                          # computed per batch, so shuffling validation
                          # would make the reported metric vary run to run.
                          shuffle=(phase == 'training'),
                          num_workers=num_workers)
        for phase in ('training', 'validation')
    }

    model = model.to(device)

    # Starting from +inf means the first validation epoch always checkpoints.
    best_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        for phase in ('training', 'validation'):
            epoch_loss, epoch_dice = _run_phase(
                model, dataloaders[phase], criterion, optimizer, device,
                phase)

            print('{} Loss: {:.4f} Dice: {:.4f}'.format(
                phase, epoch_loss, epoch_dice))

            # Save the weights to disk when validation loss improves. The
            # same formatted path is used for every save.
            if phase == 'validation' and epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(), cp_path.format(epoch))
                print('Checkpoint {} saved!'.format(epoch))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Ejemplo n.º 3
0
 def test_dice_score(self):
     """Check that the first entry of TEST_CASES yields a Dice score of 1.0."""
     first_case = TEST_CASES[0]
     score = PairwiseMeasures(
         seg_img=first_case['seg_img'],
         ref_img=first_case['ref_img'],
     ).dice_score()
     self.assertEqual(score, 1.0)