Example #1
def predict(net, full_img, device, input_size, mask_way='warp'):
    '''
    :param mask_way: Sets the way to obtain the mask. Can be 'warp' or 'segm'.
    '''

    # Preprocess input image:
    img = BasicDataset.preprocess_img(full_img, input_size)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    net.eval()

    # Predict:
    with torch.no_grad():
        logits, rec_mask, theta = net.predict(
            img, warp=(mask_way == 'warp'))

    if mask_way == 'warp':
        mask = rec_mask * net.n_classes
        mask = mask.type(torch.IntTensor).cpu().numpy().astype(np.uint8)
    elif mask_way == 'segm':
        mask = preds_to_masks(logits, net.n_classes)
    else:
        raise NotImplementedError

    return mask, theta
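For reference, a minimal invocation sketch (not part of the original snippet): the model class, checkpoint path, input size, and file name below are assumptions and should be replaced with the project's actual ones.

import torch
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = ReconstructorNet(n_channels=3, n_classes=5)  # hypothetical model class/constructor
net.load_state_dict(torch.load('checkpoints/CP_epoch5.pth', map_location=device))
net.to(device)

frame = Image.open('frame_0001.png')
mask, theta = predict(net, frame, device, input_size=(640, 360), mask_way='warp')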
Example #2
def predict_img(net, full_img, device, input_size):
    net.eval()

    img = BasicDataset.preprocess_img(full_img, input_size)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        preds = net(img)
        masks = preds_to_masks(preds, net.n_classes)  # GPU tensor -> CPU numpy

    return masks
Example #3
def predict_img(net, full_img, device, input_size):
    # Preprocess input image:
    img = BasicDataset.preprocess_img(full_img, input_size)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    net.eval()

    # Predict:
    with torch.no_grad():
        mask_pred, mask_proj = net(img)

    # Tensors to ndarrays:
    mask = preds_to_masks(mask_pred, net.n_classes)
    proj = mask_proj * net.n_classes
    proj = proj.type(torch.IntTensor).cpu().numpy().astype(np.uint8)

    return mask, proj
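A hedged usage sketch for this variant, reusing the net/device setup from the sketch after Example #1. It assumes preds_to_masks returns a batch of HxW uint8 label maps and that the projected mask has a matching layout; both are assumptions, not shown in the snippet above.

import numpy as np
from PIL import Image

frame = Image.open('frame_0001.png')
mask, proj = predict_img(net, frame, device, input_size=(640, 360))

# Save the first item of the batch; squeeze in case a channel axis is present.
Image.fromarray(np.squeeze(mask[0]).astype(np.uint8)).save('mask_0001.png')
Image.fromarray(np.squeeze(proj[0]).astype(np.uint8)).save('proj_0001.png')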
Example #4
def train_net(net,
              device,
              img_dir,
              mask_dir,
              val_names,
              num_classes,
              cp_dir=None,
              log_dir=None,
              epochs=5,
              batch_size=1,
              lr=0.001,
              target_size=(1280, 720),
              vizualize=False):
    '''
    Train U-Net model
    '''
    # Prepare dataset:
    train_ids, val_ids = split_on_train_val(img_dir, val_names)
    train = BasicDataset(train_ids, img_dir, mask_dir, num_classes,
                         target_size)
    val = BasicDataset(val_ids, img_dir, mask_dir, num_classes, target_size)
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    n_train = len(train)
    n_val = len(val)

    writer = SummaryWriter(
        log_dir=log_dir,
        comment=
        f'LR_{lr}_BS_{batch_size}_SIZE_{target_size}_DECONV_{net.bilinear}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Images dir:      {img_dir}
        Masks dir:       {mask_dir}
        Checkpoints dir: {cp_dir}
        Log dir:         {log_dir}
        Device:          {device.type}
        Input size:      {target_size}
        Vizualize:       {vizualize}
    ''')

    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (5 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)

                    # Validation:
                    result = eval_net(net,
                                      val_loader,
                                      device,
                                      verbose=vizualize)
                    val_score = result['val_score']
                    scheduler.step(val_score)

                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    if vizualize:
                        # Postprocess predicted mask for tensorboard visualization:
                        pred_masks = preds_to_masks(result['preds'],
                                                    net.n_classes)
                        pred_masks = mask_to_image(pred_masks)
                        pred_masks = pred_masks[..., ::-1]  # flip color channels (RGB <-> BGR) while still NHWC
                        pred_masks = np.transpose(pred_masks, (0, 3, 1, 2))
                        pred_masks = pred_masks.astype(np.float32) / 255.0

                        # Save the results for tensorboard visualization:
                        writer.add_images('imgs', result['imgs'], global_step)
                        writer.add_images('preds', pred_masks, global_step)

        if cp_dir is not None:
            try:
                os.mkdir(cp_dir)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), cp_dir + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
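A possible way to call this training routine, as a sketch under assumptions: the UNet constructor, directory layout, and validation names below are placeholders, not taken from these snippets.

import logging
import torch

logging.basicConfig(level=logging.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = UNet(n_channels=3, n_classes=5, bilinear=True)  # hypothetical constructor
net.to(device)

train_net(net, device,
          img_dir='data/imgs/', mask_dir='data/masks/',
          val_names=['game_001'], num_classes=5,
          cp_dir='checkpoints/', log_dir='runs/unet/',
          epochs=10, batch_size=4, lr=1e-3,
          target_size=(640, 360), vizualize=False)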
Example #5
def train_net(net, device, img_dir, mask_dir, val_names,  num_classes, opt='RMSprop',
              aug=None, cp_dir=None, log_dir=None, epochs=5, batch_size=1,
              lr=0.0001, w_decay=1e-8, target_size=(1280,720), vizualize=False):
    '''
    Train U-Net model
    '''
    # Prepare dataset:
    train_ids, val_ids = split_on_train_val(img_dir, val_names)
    train = BasicDataset(train_ids, img_dir, mask_dir, num_classes, target_size, aug=aug)
    val = BasicDataset(val_ids, img_dir, mask_dir, num_classes, target_size)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8,
                              pin_memory=True, worker_init_fn=worker_init_fn)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8,
                            pin_memory=True, drop_last=True)
    n_train = len(train)
    n_val = len(val)

    writer = SummaryWriter(log_dir=log_dir,
                           comment=f'LR_{lr}_BS_{batch_size}_SIZE_{target_size}_DECONV_{net.bilinear}')
    global_step = 0

    logging.info(f'''Starting training:
        Optimizer:       {opt}
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Weight decay:    {w_decay}
        Training size:   {n_train}
        Validation size: {n_val}
        Images dir:      {img_dir}
        Masks dir:       {mask_dir}
        Checkpoints dir: {cp_dir}
        Log dir:         {log_dir}
        Device:          {device.type}
        Input size:      {target_size}
        Vizualize:       {vizualize}
        Augmentation:    {aug}
    ''')

    if opt == 'RMSprop':
        optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=w_decay, momentum=0.9)
    elif opt == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=w_decay, momentum=0.9)
    elif opt == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=w_decay)
    else:
        print('Optimizer {} is not supported yet'.format(opt))
        raise NotImplementedError

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=3)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device)

                # #######
                # from torchvision import transforms
                # for j, (img, mask) in enumerate(zip(imgs,true_masks)):
                #     img = transforms.ToPILImage(mode='RGB')(img)
                #     mask = transforms.ToPILImage(mode='L')(mask.to(dtype=torch.uint8))
                #
                #     # Save:
                #     dst_dir = '/media/darkalert/c02b53af-522d-40c5-b824-80dfb9a11dbb/boost/datasets/court_segmentation/NCAAM_2/TEST/'
                #     dst_path = os.path.join(dst_dir, '{}-{}.jpeg'.format(global_step,j))
                #     img.save(dst_path, 'JPEG')
                #     dst_path = os.path.join(dst_dir, '{}-{}_mask.png'.format(global_step,j))
                #     mask.save(dst_path, 'PNG')
                # #######

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (5 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

                    # Validation:
                    result = eval_net(net, val_loader, device, verbose=vizualize)
                    val_score = result['val_score']
                    scheduler.step(val_score)

                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    if vizualize:
                        # Postprocess predicted mask for tensorboard visualization:
                        pred_masks = preds_to_masks(result['preds'], net.n_classes)
                        pred_masks = onehot_to_image(pred_masks, num_classes)
                        pred_masks = pred_masks[..., ::-1]  # flip color channels (RGB <-> BGR) while still NHWC
                        pred_masks = np.transpose(pred_masks, (0, 3, 1, 2))
                        pred_masks = pred_masks.astype(np.float32) / 255.0

                        # Save the results for tensorboard visualization:
                        writer.add_images('imgs', result['imgs'], global_step)
                        writer.add_images('preds', pred_masks, global_step)

        if cp_dir is not None:
            try:
                os.mkdir(cp_dir)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       cp_dir + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
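The training DataLoader above passes a worker_init_fn that is not shown in these snippets. A common implementation (an assumption here; the project's own may differ) reseeds NumPy in each DataLoader worker so random augmentations are not duplicated across workers:

import numpy as np
import torch

def worker_init_fn(worker_id):
    # Derive a distinct NumPy seed for each DataLoader worker from the
    # torch base seed, so per-worker augmentations differ.
    seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(seed)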
Example #6
def train_net(net,
              device,
              img_dir,
              mask_dir,
              val_names,
              num_classes,
              seg_loss,
              rec_loss,
              seg_lambda,
              rec_lambda,
              poi_loss,
              poi_lambda,
              anno_dir=None,
              anno_keys=None,
              opt='RMSprop',
              aug=None,
              cp_dir=None,
              log_dir=None,
              epochs=5,
              batch_size=1,
              lr=0.0001,
              w_decay=1e-8,
              target_size=(1280, 720),
              only_ncaam=False,
              vizualize=False):
    '''
    Train UNet+UNetReg+ResNetReg model
    '''
    # Prepare dataset:
    train_ids, val_ids = split_on_train_val(img_dir,
                                            val_names,
                                            only_ncaam=only_ncaam)
    train = BasicDataset(train_ids,
                         img_dir,
                         mask_dir,
                         anno_dir,
                         anno_keys,
                         num_classes,
                         target_size,
                         aug=aug)
    val = BasicDataset(val_ids, img_dir, mask_dir, anno_dir, anno_keys,
                       num_classes, target_size)
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    n_train = len(train)
    n_val = len(val)

    # Logger:
    writer = SummaryWriter(
        log_dir=log_dir,
        comment=
        f'LR_{lr}_BS_{batch_size}_SIZE_{target_size}_DECONV_{net.bilinear}')
    logging.info(f'''Starting training:
        Optimizer:       {opt}
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Weight decay:    {w_decay}
        Segmentation:    {seg_loss}
        Reconstruction:  {rec_loss}
        PoI loss:        {poi_loss}
        Seg Lambda:      {seg_lambda}
        Rec Lambda:      {rec_lambda}
        PoI Lambda:      {poi_lambda}
        Training size:   {n_train}
        Validation size: {n_val}
        Images dir:      {img_dir}
        Masks dir:       {mask_dir}
        Annotation dir:  {anno_dir}
        Annotation keys: {anno_keys}
        Checkpoints dir: {cp_dir}
        Log dir:         {log_dir}
        Device:          {device.type}
        Input size:      {target_size}
        Vizualize:       {vizualize}
        Augmentation:    {aug}
    ''')

    # Optimizer:
    if opt == 'RMSprop':
        optimizer = optim.RMSprop(net.parameters(),
                                  lr=lr,
                                  weight_decay=w_decay,
                                  momentum=0.9)
    elif opt == 'SGD':
        optimizer = optim.SGD(net.parameters(),
                              lr=lr,
                              weight_decay=w_decay,
                              momentum=0.9)
    elif opt == 'Adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=lr,
                               betas=(0.9, 0.999),
                               weight_decay=w_decay)
    else:
        print('Optimizer {} is not supported yet'.format(opt))
        raise NotImplementedError

    # Scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=3)

    # Losses:
    if seg_loss == 'CE':
        criterion = nn.CrossEntropyLoss()
    elif seg_loss == 'focal':
        criterion = kornia.losses.FocalLoss(alpha=1.0,
                                            gamma=2.0,
                                            reduction='mean')
    else:
        raise NotImplementedError

    if rec_loss == 'MSE':
        rec_criterion = nn.MSELoss()
    elif rec_loss == 'SmoothL1':
        rec_criterion = nn.SmoothL1Loss()
    else:
        raise NotImplementedError

    if poi_loss == 'MSE':
        poi_criterion = nn.MSELoss()
    elif poi_loss == 'SmoothL1':
        poi_criterion = nn.SmoothL1Loss()
    else:
        raise NotImplementedError

    global_step = 0

    # Training loop:
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                # Get data:
                imgs = batch['image']
                true_masks = batch['mask']
                gt_poi = batch['poi'] if 'poi' in batch else None
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                # CPU -> GPU:
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device)
                if gt_poi is not None:
                    gt_poi = gt_poi.to(device=device, dtype=torch.float32)

                # Forward:
                logits, rec_masks, poi = net(imgs)

                # Calculate CrossEntropy loss:
                seg_loss = criterion(logits, true_masks) * seg_lambda

                # Calculate reconstruction loss for regressors:
                gt_masks = true_masks.to(
                    dtype=torch.float32) / float(num_classes)
                rec_loss = rec_criterion(rec_masks, gt_masks) * rec_lambda

                # Total loss:
                loss = seg_loss + rec_loss

                # Log:
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Loss/train seg', seg_loss.item(),
                                  global_step)
                writer.add_scalar('Loss/train rec', rec_loss.item(),
                                  global_step)
                logs = {
                    'Seg_loss': seg_loss.item(),
                    'Rec_loss': rec_loss.item(),
                    'Tot loss': loss.item()
                }

                # Calculate a regression loss for PoI:
                # if gt_poi is not None:
                #     poi_loss = poi_criterion(poi, gt_poi) * poi_lambda
                #     loss += poi_loss
                #     writer.add_scalar('Loss/train PoI', poi_loss.item(), global_step)
                #     logs['PoI_loss'] = poi_loss.item()

                epoch_loss += loss.item()
                pbar.set_postfix(**logs)

                # Backward:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                # Validation step:
                if global_step % (n_train // (5 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        t = tag.replace('.', '/')
                        writer.add_histogram('weights/' + t,
                                             value.data.cpu().numpy(),
                                             global_step)
                        if value.grad is not None:
                            writer.add_histogram('grads/' + t,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)

                    # Evaluate:
                    result = eval_reconstructor(net,
                                                val_loader,
                                                device,
                                                verbose=vizualize)
                    val_ce_score = result['val_seg_score']
                    val_rec_score = result['val_rec_score']
                    val_tot_score = val_ce_score + val_rec_score
                    scheduler.step(val_tot_score)

                    # Validation log:
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    writer.add_scalar('Loss/test', val_tot_score, global_step)
                    writer.add_scalar('Loss/test_seg', val_ce_score,
                                      global_step)
                    writer.add_scalar('Loss/test_rec', val_rec_score,
                                      global_step)
                    logging.info(
                        '\nValidation tot: {}, seg: {}, rec: {}'.format(
                            val_tot_score, val_ce_score, val_rec_score))

                    if vizualize:
                        # Postprocess predicted mask for tensorboard visualization:
                        pred_masks = preds_to_masks(result['logits'],
                                                    net.n_classes)
                        pred_masks = onehot_to_image(pred_masks, num_classes)
                        pred_masks = pred_masks[..., ::-1]  # rgb to bgr
                        pred_masks = np.transpose(pred_masks, (0, 3, 1, 2))
                        pred_masks = pred_masks.astype(np.float32) / 255.0

                        rec_masks = result['rec_masks'] * num_classes
                        rec_masks = rec_masks.type(
                            torch.IntTensor).cpu().numpy().astype(np.uint8)
                        rec_masks = onehot_to_image(rec_masks, num_classes)
                        rec_masks = rec_masks[..., ::-1]
                        rec_masks = np.transpose(rec_masks, (0, 3, 1, 2))
                        rec_masks = rec_masks.astype(np.float32) / 255.0

                        # Concatenate all images:
                        output = np.concatenate(
                            (result['imgs'], pred_masks, rec_masks), axis=2)

                        # Save the results for tensorboard visualization:
                        writer.add_images('output', output, global_step)

                        # import cv2
                        # output2 = np.transpose(output, (0, 2, 3, 1)) * 255.0
                        # for ii, out_img in enumerate(output2):
                        #     out_img = cv2.cvtColor(out_img, cv2.COLOR_BGR2RGB)
                        #     out_path = '/media/darkalert/c02b53af-522d-40c5-b824-80dfb9a11dbb/boost/datasets/court_segmentation/NCAA2020+_dev/test/' + str(ii) + '.png'
                        #     cv2.imwrite(out_path, out_img)

        # Save checkpoint:
        if cp_dir is not None:
            try:
                os.mkdir(cp_dir)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), cp_dir + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
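As a sketch of how the extra loss configuration of this variant might be supplied (the model class, directories, annotation keys, and lambda values are assumptions, not taken from the snippet):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNetReg(n_channels=3, n_classes=5)  # hypothetical UNet+regressor wrapper
net.to(device)

train_net(net, device,
          img_dir='data/imgs/', mask_dir='data/masks/',
          val_names=['game_001'], num_classes=5,
          seg_loss='CE', rec_loss='MSE', seg_lambda=1.0, rec_lambda=10.0,
          poi_loss='SmoothL1', poi_lambda=1.0,
          anno_dir='data/anno/', anno_keys=['poi'],
          opt='Adam', epochs=10, batch_size=2, lr=1e-4,
          target_size=(640, 360))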
Example #7
def train_net(net,
              device,
              img_dir,
              mask_dir,
              val_names,
              num_classes,
              opt='RMSprop',
              aug=None,
              cp_dir=None,
              log_dir=None,
              epochs=5,
              batch_size=1,
              lr=0.0001,
              w_decay=1e-8,
              l2_lymbda=10.0,
              target_size=(1280, 720),
              vizualize=False):
    '''
    Train U-Net model
    '''
    # Prepare dataset:
    train_ids, val_ids = split_on_train_val(img_dir, val_names)
    train = BasicDataset(train_ids,
                         img_dir,
                         mask_dir,
                         num_classes,
                         target_size,
                         aug=aug)
    val = BasicDataset(val_ids, img_dir, mask_dir, num_classes, target_size)
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    n_train = len(train)
    n_val = len(val)

    writer = SummaryWriter(
        log_dir=log_dir,
        comment=
        f'LR_{lr}_BS_{batch_size}_SIZE_{target_size}_DECONV_{net.bilinear}')
    global_step = 0

    logging.info(f'''Starting training:
        Optimizer:       {opt}
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Weight decay:    {w_decay}
        L2 lambda:       {l2_lymbda}
        Training size:   {n_train}
        Validation size: {n_val}
        Images dir:      {img_dir}
        Masks dir:       {mask_dir}
        Checkpoints dir: {cp_dir}
        Log dir:         {log_dir}
        Device:          {device.type}
        Input size:      {target_size}
        Vizualize:       {vizualize}
        Augmentation:    {aug}
    ''')

    if opt == 'RMSprop':
        optimizer = optim.RMSprop(net.parameters(),
                                  lr=lr,
                                  weight_decay=w_decay,
                                  momentum=0.9)
    elif opt == 'SGD':
        optimizer = optim.SGD(net.parameters(),
                              lr=lr,
                              weight_decay=w_decay,
                              momentum=0.9)
    elif opt == 'Adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=lr,
                               betas=(0.9, 0.999),
                               weight_decay=w_decay)
    else:
        print('Optimizer {} is not supported yet'.format(opt))
        raise NotImplementedError

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=3)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()
    l2_criterion = nn.MSELoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device)

                masks_pred, projected_masks = net(imgs)

                # Calculate CrossEntropy loss:
                ce_loss = criterion(masks_pred, true_masks)

                # Calculate reconstruction L2-loss for STN:
                gt_masks = true_masks.to(
                    dtype=torch.float32) / float(num_classes)
                l2_loss = l2_criterion(projected_masks, gt_masks) * l2_lymbda

                # Total loss:
                loss = ce_loss + l2_loss
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Loss/train CE', ce_loss.item(), global_step)
                writer.add_scalar('Loss/train L2', l2_loss.item(), global_step)
                pbar.set_postfix(
                    **{
                        'loss (batch)': loss.item(),
                        'CE_loss': ce_loss.item(),
                        'L2_loss': l2_loss.item()
                    })

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % (n_train // (5 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)

                    # Validation:
                    result = eval_stn(net,
                                      val_loader,
                                      device,
                                      verbose=vizualize)
                    val_tot_score = result['val_tot_score']
                    val_ce_score = result['val_ce_score']
                    val_mse_score = result['val_mse_score']
                    scheduler.step(val_tot_score)

                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    writer.add_scalar('Loss/test', val_tot_score, global_step)
                    writer.add_scalar('Loss/test_mse', val_mse_score,
                                      global_step)
                    writer.add_scalar('Loss/test_ce', val_ce_score,
                                      global_step)
                    logging.info(
                        'Validation total: {}, CE: {}, MSE: {}'.format(
                            val_tot_score, val_ce_score, val_mse_score))

                    if vizualize:
                        # Postprocess predicted mask for tensorboard visualization:
                        pred_masks = preds_to_masks(result['preds'],
                                                    net.n_classes)
                        pred_masks = onehot_to_image(pred_masks, num_classes)
                        pred_masks = pred_masks[..., ::-1]  # flip color channels (RGB <-> BGR) while still NHWC
                        pred_masks = np.transpose(pred_masks, (0, 3, 1, 2))
                        pred_masks = pred_masks.astype(np.float32) / 255.0

                        projs = result['projs'] * num_classes
                        projs = projs.type(
                            torch.IntTensor).cpu().numpy().astype(np.uint8)
                        projs = onehot_to_image(projs, num_classes)
                        projs = projs[..., ::-1]  # flip color channels (RGB <-> BGR) while still NHWC
                        projs = np.transpose(projs, (0, 3, 1, 2))
                        projs = projs.astype(np.float32) / 255.0

                        # Concatenate all images:
                        output = np.concatenate(
                            (result['imgs'], pred_masks, projs), axis=2)

                        # Save the results for tensorboard visualization:
                        writer.add_images('output', output, global_step)

        if cp_dir is not None:
            try:
                os.mkdir(cp_dir)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), cp_dir + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()