Exemplo n.º 1
0
def eval_net_full(net, loader, device, ratio, threshold=0.3):
    """Evaluate ``net`` on a random sample of full-size images.

    A fraction ``ratio`` of the images exposed by ``loader.dataset.dataset``
    is sampled (without replacement). Each image is predicted with
    ``predict_full_image`` and scored with the Dice coefficient (on the
    thresholded prediction) and the Dice loss (on the raw prediction).

    Args:
        net: the segmentation network passed to ``predict_full_image``.
        loader: a loader whose ``dataset.dataset`` provides
            ``get_real_length()``, ``get_raw_image(i)`` and
            ``get_raw_mask(i)`` (presumably ``loader.dataset`` is a
            ``torch.utils.data.Subset`` — confirm against caller).
        device: torch device on which metrics are computed.
        ratio: fraction of the dataset to evaluate. At least one image and at
            most the whole dataset is evaluated.
        threshold: probability cut-off used to binarise the prediction for
            the Dice coefficient (default 0.3, the historical value).

    Returns:
        Tuple ``(mean_dice_coeff, mean_dice_loss, img, true_mask,
        prediction)``. The last three come from the final evaluated image:
        ``img`` is the raw image transposed to channel-first, and
        ``true_mask`` has an extra leading axis added.

    Raises:
        ValueError: if the underlying dataset is empty.
    """
    # NOTE(review): indices are drawn from the *underlying* dataset, not the
    # validation subset, so training images may be included — confirm intent.
    dataset = loader.dataset.dataset
    n_available = dataset.get_real_length()
    if n_available <= 0:
        raise ValueError("eval_net_full: the underlying dataset is empty")

    # Always evaluate at least one image (avoids a division by zero and an
    # unbound `img` on return), and never more than exist, because sampling
    # is done without replacement so no image is scored twice.
    num_images = min(max(1, int(ratio * n_available)), n_available)
    image_idx = np.random.choice(n_available, num_images, replace=False)

    tot = 0.0
    loss = 0.0
    img = true_mask = prediction = None
    with torch.no_grad(), tqdm(total=num_images,
                               desc="Full Validation round",
                               unit="img",
                               leave=False) as pbar:
        for i in image_idx:
            img, raw_mask = dataset.get_raw_image(i), dataset.get_raw_mask(i)
            # assumes predict_full_image returns a numpy probability map in
            # [0, 1] (it is thresholded directly below) — TODO confirm
            prediction = predict_full_image(net, img, device)
            prediction = torch.from_numpy(prediction).to(device=device).float()
            true_mask = torch.from_numpy(np.expand_dims(
                raw_mask, 0)).to(device=device).float()

            tot += dice_coeff((prediction > threshold).float(),
                              true_mask).item()
            loss += dice_loss(prediction, true_mask).item()
            pbar.update(1)  # one image processed (not its index)

    return tot / num_images, loss / num_images, torch.from_numpy(
        img.transpose((2, 0, 1))), true_mask, prediction
Exemplo n.º 2
0
def eval_net(net, loader, device, n_val, threshold=0.3):
    """Evaluate ``net`` on ``loader`` with the Dice coefficient and Dice loss.

    Args:
        net: segmentation network producing logits. The caller is expected to
            have put it in eval mode.
        loader: iterable of batches, each a dict with ``'image'`` and
            ``'mask'`` tensors.
        device: torch device to run on.
        n_val: expected number of samples. It is used only as the progress
            bar total; the averages divide by the number actually evaluated.
        threshold: probability cut-off used to binarise predictions for the
            Dice coefficient (default 0.3, the historical value).

    Returns:
        ``(mean_dice_coeff, mean_dice_loss)`` over all evaluated samples, or
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    tot = 0.0
    loss = 0.0
    n_seen = 0
    # no_grad: validation must not build an autograd graph (wastes memory).
    with torch.no_grad(), tqdm(total=n_val, desc='Validation round',
                               unit='img', leave=False) as pbar:
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=torch.float32)

            probs = torch.sigmoid(net(imgs))
            for true_mask, prob in zip(true_masks, probs):
                pred = (prob > threshold).float()
                tot += dice_coeff(pred, true_mask).item()
                # Soft Dice loss on probabilities, matching the training log
                # and eval_net_full (on a binary mask it would just mirror
                # the coefficient).
                loss += dice_loss(prob, true_mask).item()
                n_seen += 1
            pbar.update(imgs.shape[0])

    if n_seen == 0:
        return 0.0, 0.0
    return tot / n_seen, loss / n_seen
Exemplo n.º 3
0
def train_model(epochs, criterion, optimizer, lr_scheduler, net, train_loader,
                val_loader, dir_checkpoint, logger, n_train, n_val, batch_size,
                writer, val_ratio, balance_classes):
    """Train ``net`` and log progress/validation metrics to ``writer``.

    Args:
        epochs: number of passes over ``train_loader``.
        criterion: loss on logits vs. masks. It is logged as "BCE_loss" and
            its ``pos_weight`` is set when ``balance_classes`` is true
            (presumably ``nn.BCEWithLogitsLoss`` — confirm).
        optimizer: torch optimizer for ``net``'s parameters.
        lr_scheduler: optional scheduler stepped once per epoch with the mean
            epoch loss scaled by 1000 and truncated to int (looks like
            ``ReduceLROnPlateau`` — confirm), or ``None``.
        net: model to train. It is moved to CUDA if available, else CPU.
        train_loader / val_loader: loaders yielding dicts with ``'image'``
            and ``'mask'`` tensors.
        dir_checkpoint: directory where checkpoints are written.
        logger: logger used for progress messages.
        n_train / n_val: number of training / validation samples.
        batch_size: training batch size, used to space validations.
        writer: TensorBoard-style writer (``add_scalar``/``add_images``).
        val_ratio: fraction of images used by the periodic full-image
            validation.
        balance_classes: if true, set ``criterion.pos_weight`` to the
            per-batch negative/positive pixel ratio.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device {device}')
    net.to(device=device)

    dataset_length = n_val + n_train
    # Patch validation runs roughly ten times per dataset-sized number of
    # batches. Clamp to 1 so small datasets don't trigger a modulo by zero.
    val_interval = max(1, dataset_length // (10 * batch_size))
    global_step = 0

    for epoch in range(epochs):
        net.train()  # Sets module in training mode
        epoch_loss = []
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(device=device,
                                              dtype=torch.float32)

                if balance_classes:
                    # Neg / pos ratio to rectify class imbalance. Clamp the
                    # positive count: a batch without positives would
                    # otherwise give an infinite weight (and a NaN loss).
                    n_neg = torch.sum(torch.abs(true_masks - 1))
                    n_pos = torch.clamp(torch.sum(true_masks), min=1.0)
                    criterion.pos_weight = (n_neg / n_pos).detach().reshape(
                        1).to(device=device, dtype=torch.float32)

                # Optimization step
                optimizer.zero_grad()
                masks_pred = net(imgs)  # Make predictions (logits)
                loss = criterion(masks_pred, true_masks)  # Evaluate loss
                batch_loss = loss.item()
                loss.backward()
                optimizer.step()

                # Add data to tensorboard
                epoch_loss.append(batch_loss)
                writer.add_scalar('Train/BCE_loss', batch_loss, global_step)
                # Logging-only quantities: no need to keep an autograd graph.
                with torch.no_grad():
                    pred_probs = torch.sigmoid(masks_pred)
                    d_loss = dice_loss(pred_probs, true_masks).item()
                writer.add_scalar('Train/Dice_loss', d_loss, global_step)
                pbar.set_postfix(**{'loss (batch)': batch_loss})
                pbar.update(imgs.shape[0])

                global_step += 1

                if n_val > 0 and global_step % val_interval == 0:
                    net.eval()
                    val_score, val_loss = eval_net(net, val_loader, device,
                                                   n_val)
                    net.train()  # Reset in training mode

                    logger.info('Validation Dice Coeff: {}'.format(val_score))
                    writer.add_scalar('Validation/Dice_coef', val_score,
                                      global_step)
                    writer.add_scalar('Validation/Dice_loss', val_loss,
                                      global_step)
                    writer.add_images('images', imgs, global_step)
                    writer.add_images('masks/true', true_masks, global_step)
                    writer.add_images('masks/pred', pred_probs, global_step)

                if n_val > 0 and global_step % 300 == 0:
                    net.eval()
                    (val_full_score, val_full_loss, img, true_mask,
                     mask_pred) = eval_net_full(net, val_loader, device,
                                                val_ratio)
                    net.train()

                    logger.info('Full Validation Dice Coeff: {}'.format(
                        val_full_score))
                    writer.add_scalar('Full_Validation/Dice_coef',
                                      val_full_score, global_step)
                    writer.add_scalar('Full_Validation/Dice_loss',
                                      val_full_loss, global_step)

                    # add_images expects a batch axis in front.
                    writer.add_images('full_images', img[None, :, :, :],
                                      global_step)
                    writer.add_images('full_masks/true',
                                      true_mask[None, :, :, :], global_step)
                    writer.add_images('full_masks/pred',
                                      mask_pred[None, :, :, :], global_step)

                if (global_step + 1) % SAVE_EVERY == 0:
                    torch.save(net.state_dict(),
                               os.path.join(dir_checkpoint,
                                            f'CP_epoch{epoch + 1}.pth'))
                    logger.info(f'Checkpoint {epoch + 1} saved !')

        # Skip the scheduler on an empty epoch: np.mean([]) is NaN and
        # int(NaN) raises.
        if lr_scheduler is not None and epoch_loss:
            mean_epoch_loss = float(np.mean(epoch_loss))
            lr_scheduler.step(int(mean_epoch_loss * 1000))
            # Log a scalar (the mean), not the list of batch losses.
            writer.add_scalar("LR/epoch_loss", mean_epoch_loss, global_step)
            writer.add_scalar("LR", get_lr(optimizer), global_step)

    writer.close()
    torch.save(net.state_dict(), os.path.join(dir_checkpoint, "final.pth"))