Example #1
 def test_dice_loss(self):
     """DiceLoss values must lie in the closed interval [0, 1]."""
     results = _compute_criterion(DiceLoss())
     results = np.array(results)
     # The Dice coefficient can legitimately reach both endpoints
     # (no overlap -> 0, perfect overlap -> 1), so the bounds are
     # inclusive — the original strict >/< contradicted the stated
     # invariant "[0, 1]" and would fail on boundary predictions.
     assert np.all(results >= 0)
     assert np.all(results <= 1)
Example #2
def train_net(model: UNet3D,
              device,
              loss_fnc=None,
              eval_criterion=None,
              epochs=5,
              batch_size=1,
              learning_rate=0.0002,
              val_percent=0.04,
              test_percent=0.1,
              name='U-Net',
              save_cp=True,
              tests=None):
    """Train *model* on the BasicDataset split and log progress to TensorBoard.

    Args:
        model: the 3D U-Net to train.
        device: torch device the dataset tensors are placed on.
        loss_fnc: training criterion; defaults to
            ``DiceLoss(sigmoid_normalization=False)`` when ``None``.
        eval_criterion: validation metric; defaults to ``MeanIoU`` when ``None``.
        epochs: number of passes over the training loader.
        batch_size: samples per batch.
        learning_rate: Adam learning rate.
        val_percent / test_percent: dataset split fractions.
        name: run name used in logs and checkpoint filenames.
        save_cp: when True, save a checkpoint and plots after every epoch.
        tests: optional explicit list of test files forwarded to the splitter.
    """
    # NOTE: defaults are resolved here rather than in the signature —
    # a `DiceLoss()` default argument would be constructed once at import
    # time and shared (mutably) across every call to train_net.
    if loss_fnc is None:
        loss_fnc = DiceLoss(sigmoid_normalization=False)
    if eval_criterion is None:
        eval_criterion = MeanIoU()

    data_set = BasicDataset(dir_img, dir_mask, 'T1', device)
    train_loader, val_loader, test_loader = data_set.split_to_loaders(
        val_percent, test_percent, batch_size, test_files=tests)

    writer = SummaryWriter(comment=f'LR_{learning_rate}_BS_{batch_size}')
    global_step = 0
    logging.info(f'''Starting {name} training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {len(train_loader)}
        Validation size: {len(val_loader)}
        Testing size:    {len(test_loader)}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=0.00001)
    losses = []
    val_scores = []

    # Validate roughly 5 times per epoch. max(1, ...) guards against a
    # ZeroDivisionError when the training loader is shorter than
    # 5 * batch_size; the value is loop-invariant, so hoist it.
    val_interval = max(1, len(train_loader) // (5 * batch_size))

    for epoch in range(epochs):

        epoch_loss = 0
        for batch in train_loader:
            # Re-enable train mode every iteration because validate()
            # presumably switches the model to eval mode — TODO confirm.
            model.train()
            start_time = timeit.default_timer()

            img = batch['image']
            mask = batch['mask']

            masks_pred = model(img)

            loss = loss_fnc(masks_pred, mask)

            epoch_loss += loss.item()
            losses.append(loss.item())

            writer.add_scalar('Loss/train', loss.item(), global_step)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            global_step += 1
            elapsed = timeit.default_timer() - start_time
            logging.info(
                f'I: {global_step}, Loss: {loss.item()} in {elapsed} seconds')

            if global_step % val_interval == 0:
                val_score = validate(model, val_loader, loss_fnc,
                                     eval_criterion)
                val_scores.append(val_score)

                writer.add_scalar('Validation/test', val_score, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                # Directory already exists (or cannot be created); in the
                # latter case torch.save below will surface the real error.
                pass
            # os.path.join is safe whether or not dir_checkpoint carries a
            # trailing separator; plain '+' silently mangled the path.
            torch.save(model.state_dict(),
                       os.path.join(dir_checkpoint,
                                    f'{name}_epoch{epoch + 1}.pth'))
            logging.info(f'Epoch: {epoch + 1} Loss: {epoch_loss}')
            logging.info(f'Checkpoint {epoch + 1} saved !')
            plot_cost(losses, name='Loss' + str(epoch), model_name=name)
            plot_cost(val_scores,
                      name='Validation' + str(epoch),
                      model_name=name)

    writer.close()