Example no. 1
0
def check_wso():
    """Visually inspect the WSO (window settings optimization) layer.

    Loads validation slices from fold 0, applies the WSO module to each
    batch and displays the 4 windowed channels of every sample with
    matplotlib. Debug/inspection utility; blocks on plt.show() per batch.
    """
    from rsna19.data import dataset
    import albumentations.pytorch
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    wso = WSO()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[0],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
    )
    batch_size = 2
    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    for data in data_loader:
        img = data['image'].float().cpu()

        windowed_img = wso(img).detach().numpy()

        # One column of 4 windowed channels per sample. The previous
        # subplots(4, 1) layout drew every sample into the same axes,
        # so only the last sample of the batch was actually visible.
        fig, ax = plt.subplots(4, batch_size)

        for batch in range(batch_size):
            for j in range(4):
                ax[j, batch].imshow(windowed_img[batch, j], cmap='gray')

        plt.show()
Example no. 2
0
def check_heatmap(model_name, fold, epoch, run=None):
    """Visualize a trained model's heatmap and per-pixel predictions.

    Loads the checkpoint for *epoch*, runs the fold's validation slices
    through the model on CPU and, for every slice, shows the input image,
    the attention heatmap and the six per-class per-pixel prediction maps.
    Blocks on plt.show() for each slice.

    Args:
        model_name: key into the MODELS registry.
        fold: validation fold whose slices are displayed.
        epoch: checkpoint epoch number to load.
        run: optional run suffix used in the checkpoint directory name.
    """
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args)
    model = model.cpu()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
        **model_info.dataset_args)

    model.eval()
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cpu()

    batch_size = 1

    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
    for iter_num, data in data_iter:
        img = data['image'].float().cpu()
        labels = data['labels'].detach().numpy()

        with torch.set_grad_enabled(False):
            pred2d, heatmap, pred = model(img,
                                          output_heatmap=True,
                                          output_per_pixel=True)
            # Rescale by the number of heatmap elements -- presumably to
            # undo a mean-pooling normalisation inside the model; TODO
            # confirm against the model implementation.
            heatmap *= np.prod(heatmap.shape[1:])

            # 0.1 is an empirically chosen display scale for the per-pixel
            # logits, not a model constant.
            pred2d = (pred2d[0]).detach().cpu().numpy() * 0.1

            fig, ax = plt.subplots(2, 4)

            for i in range(batch_size):
                print(labels[i], torch.sigmoid(pred[i]))
                # Row 0: input slice, heatmap, first two class maps;
                # row 1: the remaining four class maps. vmin/vmax pin the
                # gray scale to [0, 1] so panels are comparable.
                ax[0, 0].imshow(img[i, 0].cpu().detach().numpy(), cmap='gray')
                ax[0, 1].imshow(heatmap[i, 0].cpu().detach().numpy(),
                                cmap='gray')
                ax[0, 2].imshow(pred2d[0], cmap='gray', vmin=0, vmax=1)
                ax[0, 3].imshow(pred2d[1], cmap='gray', vmin=0, vmax=1)
                ax[1, 0].imshow(pred2d[2], cmap='gray', vmin=0, vmax=1)
                ax[1, 1].imshow(pred2d[3], cmap='gray', vmin=0, vmax=1)
                ax[1, 2].imshow(pred2d[4], cmap='gray', vmin=0, vmax=1)
                ax[1, 3].imshow(pred2d[5], cmap='gray', vmin=0, vmax=1)

            plt.show()
Example no. 3
0
def check_windows(model_name, fold, epoch, run=None):
    """Inspect the learned windowing convolution of a trained model.

    Prints each learned window as 'center +- width' (scaled by 1000 --
    presumably the input HU scaling; TODO confirm against the dataset)
    and shows the 16 windowed channels for every validation slice.
    Blocks on plt.show() for each slice.
    """
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args).cpu()

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.pytorch.ToTensorV2(),
        **model_info.dataset_args)

    model.eval()
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cpu()

    # Report every learned window in human-readable form.
    weights = model.windows_conv.weight.detach().cpu().numpy().flatten()
    biases = model.windows_conv.bias.detach().cpu().numpy()
    print(weights, biases)
    for wi, bi in zip(weights, biases):
        print(f'{-int(bi/wi*1000)} +- {int(abs(1000/wi))}')

    batch_size = 1

    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=16,
                             batch_size=batch_size)

    for iter_num, data in tqdm(enumerate(data_loader),
                               total=len(data_loader)):
        img = data['image'].float().cpu()
        labels = data['labels'].detach().numpy()

        with torch.no_grad():
            # Apply only the windowing conv + relu6, skipping the rest
            # of the network, to see what the windows extract.
            windowed = F.relu6(model.windows_conv(img)).cpu().numpy()

            fig, ax = plt.subplots(4, 4)

            for sample_idx in range(batch_size):
                print(labels[sample_idx], data['path'][sample_idx])
                for row in range(4):
                    for col in range(4):
                        channel = row * 4 + col
                        ax[row, col].imshow(windowed[sample_idx, channel],
                                            cmap='gray')

            plt.show()
Example no. 4
0
def predict(model_name,
            fold,
            epoch,
            is_test,
            df_out_path,
            mode='normal',
            run=None):
    """Run inference for one fold/checkpoint and save predictions to CSV.

    Supports simple test-time augmentation via *mode*: any combination of
    the substrings 'h_flip', 'v_flip' and 'rot90' enables the matching
    always-applied transform ('normal' applies none). The output CSV has
    one row per slice with pred_* columns (plus gt_* columns when ground
    truth is available) and path/study_id/slice_num metadata.

    Args:
        model_name: key into the MODELS registry.
        fold: fold used to select validation samples.
        epoch: checkpoint epoch number to load.
        is_test: if True, predict on test2.csv and skip ground truth.
        df_out_path: destination CSV path.
        mode: TTA mode string, checked by substring membership.
        run: optional run suffix used in the checkpoint directory name.
    """
    model_str = build_model_str(model_name, fold, run)
    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    print('\n', model_name, '\n')

    model = model_info.factory(**model_info.args)
    # Inference only needs classification output.
    model.output_segmentation = False

    preprocess_func = []
    if 'h_flip' in mode:
        preprocess_func.append(
            albumentations.HorizontalFlip(always_apply=True))
    if 'v_flip' in mode:
        preprocess_func.append(albumentations.VerticalFlip(always_apply=True))
    if 'rot90' in mode:
        preprocess_func.append(Rotate90(always_apply=True))

    dataset_valid = dataset.IntracranialDataset(
        csv_file='test2.csv' if is_test else '5fold.csv',
        folds=[fold],
        preprocess_func=albumentations.Compose(preprocess_func),
        return_labels=not is_test,
        is_test=is_test,
        # Disable segmentation-related dataset behavior for inference.
        **{
            **model_info.dataset_args, "add_segmentation_masks": False,
            "segmentation_oversample": 1
        })

    model.eval()
    print(f'load {checkpoints_dir}/{epoch:03}.pt')
    checkpoint = torch.load(f'{checkpoints_dir}/{epoch:03}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cuda()

    # Larger batches are fine at inference time (no gradients held).
    data_loader = DataLoader(dataset_valid,
                             shuffle=False,
                             num_workers=8,
                             batch_size=model_info.batch_size * 2)

    all_paths = []
    all_study_id = []
    all_slice_num = []
    all_gt = []
    all_pred = []

    data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
    for iter_num, batch in data_iter:
        with torch.set_grad_enabled(False):
            # Sigmoid over per-class logits -> per-class probabilities.
            y_hat = torch.sigmoid(model(batch['image'].float().cuda()))
            all_pred.append(y_hat.cpu().numpy())
            all_paths.extend(batch['path'])
            all_study_id.extend(batch['study_id'])
            all_slice_num.extend(batch['slice_num'].cpu().numpy())

            if not is_test:
                y = batch['labels']
                all_gt.append(y.numpy())

    pred_columns = [
        'pred_epidural', 'pred_intraparenchymal', 'pred_intraventricular',
        'pred_subarachnoid', 'pred_subdural', 'pred_any'
    ]
    gt_columns = [
        'gt_epidural', 'gt_intraparenchymal', 'gt_intraventricular',
        'gt_subarachnoid', 'gt_subdural', 'gt_any'
    ]

    if is_test:
        all_pred = np.concatenate(all_pred)
        df = pd.DataFrame(all_pred, columns=pred_columns)
    else:
        all_pred = np.concatenate(all_pred)
        all_gt = np.concatenate(all_gt)
        # Ground truth first, then predictions, column-wise.
        df = pd.DataFrame(np.hstack((all_gt, all_pred)),
                          columns=gt_columns + pred_columns)

    # Attach slice metadata alongside the numeric columns.
    df = pd.concat((df,
                    pd.DataFrame({
                        'path': all_paths,
                        'study_id': all_study_id,
                        'slice_num': all_slice_num
                    })),
                   axis=1)
    df.to_csv(df_out_path, index=False)
Example no. 5
0
def train(model_name, fold, run=None, resume_epoch=-1, use_apex=False):
    """Train a classification model for one cross-validation fold.

    Builds the model from the MODELS registry, optionally pre-fits the
    new head layers with a frozen encoder, then runs the main train/val
    loop: logs losses/metrics to tensorboard, saves a checkpoint after
    each training phase and out-of-fold predictions after each
    validation phase.

    Args:
        model_name: key into the MODELS registry.
        fold: validation fold; all remaining folds are used for training.
        run: optional run suffix for checkpoint/log directory names.
        resume_epoch: checkpoint epoch to resume from, or -1 for fresh.
        use_apex: if True, wrap model/optimizer in NVIDIA apex amp (O2).
    """
    model_str = build_model_str(model_name, fold, run)

    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    model = model.cuda()

    # try:
    #     torchsummary.summary(model, (4, 512, 512))
    #     print('\n', model_name, '\n')
    # except:
    #     raise
    #     pass

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # Shared geometric augmentations; flip/rot90 variants depend on
    # whether the model was configured for vertical flips.
    augmentations = [
        albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                        scale_limit=0.05,
                                        rotate_limit=30,
                                        interpolation=cv2.INTER_LINEAR,
                                        border_mode=cv2.BORDER_REPLICATE,
                                        p=0.80),
    ]
    if model_info.use_vflip:
        augmentations += [
            albumentations.Flip(),
            albumentations.RandomRotate90()
        ]
    else:
        augmentations += [albumentations.HorizontalFlip()]

    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-test-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose(augmentations),
        **model_info.dataset_args)

    dataset_valid = dataset.IntracranialDataset(csv_file='5fold-test-rev3.csv',
                                                folds=[fold],
                                                preprocess_func=None,
                                                **model_info.dataset_args)

    data_loaders = {
        'train':
        DataLoader(dataset_train,
                   num_workers=8,
                   shuffle=True,
                   batch_size=model_info.batch_size),
        'val':
        DataLoader(dataset_valid,
                   shuffle=False,
                   num_workers=8,
                   batch_size=model_info.batch_size)
    }

    # Optional single-slice curriculum: for the first
    # model_info.single_slice_steps epochs the model is trained on
    # 1-slice inputs (at double batch size) before switching to N slices.
    if model_info.single_slice_steps > 0:
        augmentations = [
            albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                            scale_limit=0.05,
                                            rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.80),
        ]
        if model_info.use_vflip:
            augmentations += [
                albumentations.Flip(),
                albumentations.RandomRotate90()
            ]
        else:
            augmentations += [albumentations.HorizontalFlip()]

        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose(augmentations),
            **{
                **model_info.dataset_args, "num_slices": 1
            })

        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[fold],
            preprocess_func=None,
            **{
                **model_info.dataset_args, "num_slices": 1
            })

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()

    # Per-class BCE weights; 'any' (last class) is weighted double.
    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        # Weighted multi-label BCE on logits; weights repeated per sample.
        return F.binary_cross_entropy_with_logits(
            y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    # fit the new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.freeze_encoder()
        # Prefer the 1-slice loader for the warm-up when available.
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 40000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        initial_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                pred = model(img)
                loss = criterium(pred, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )
        model.unfreeze_encoder()

    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)
    if use_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=0.2)

    print(
        f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}'
    )

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

    # 7 epochs total (hard-coded upper bound).
    for epoch_num in range(resume_epoch + 1, 7):
        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            # hasattr-style check for an optional per-epoch model hook.
            if 'on_epoch' in model.__dir__():
                model.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            # if epoch_num == model_info.single_slice_steps:
            #     print("train only conv slices/fn layers")
            #     model.module.freeze_encoder_full()
            #
            # if epoch_num == model_info.single_slice_steps+1:
            #     print("train all")
            #     model.module.unfreeze_encoder()
            #
            # if -1 < model_info.freeze_bn_step <= epoch_num:
            #     print("freeze bn")
            #     model.module.freeze_bn()

            data_iter = tqdm(enumerate(data_loader),
                             total=len(data_loader),
                             ncols=200)
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()

                with torch.set_grad_enabled(phase == 'train'):
                    # if epoch_num == model_info.single_slice_steps and phase == 'train':
                    #     with torch.set_grad_enabled(False):
                    #         model_x = model(img, output_before_combine_slices=True)
                    #     with torch.set_grad_enabled(True):
                    #         pred = model(model_x.detach(), train_last_layers_only=True)
                    # else:
                    pred = model(img)
                    loss = criterium(pred, labels)

                    if phase == 'train':
                        # Gradient accumulation: scale the loss and only
                        # step every accumulation_steps iterations.
                        if use_apex:
                            with amp.scale_loss(
                                    loss / model_info.accumulation_steps,
                                    optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            (loss / model_info.accumulation_steps).backward()

                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            # if not use_apex:
                            #     torch.nn.utils.clip_grad_norm_(model.parameters(), 32.0)
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss))

                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(
                        torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'{epoch_num} Loss: Running {np.mean(epoch_loss[-1000:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'],
                              epoch_num)  # scheduler.get_lr()[0]
            try:
                epoch_labels = np.row_stack(epoch_labels)
                epoch_predictions = np.row_stack(epoch_predictions)
                print(epoch_labels.shape, epoch_predictions.shape)
                log_metrics(logger=logger,
                            phase=phase,
                            epoch_num=epoch_num,
                            y=epoch_labels,
                            y_hat=epoch_predictions)
            except Exception:
                # Best-effort metric logging; a failure here must not
                # abort training (epoch_labels may stay a list then).
                pass
            logger.flush()

            if phase == 'val':
                # NOTE(review): stepping the scheduler with an explicit
                # epoch is deprecated in newer torch -- confirm version.
                scheduler.step(epoch=epoch_num)
                # Out-of-fold predictions for this epoch.
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    }, f'{oof_dir}/{epoch_num:03}.pt')
            else:
                # print(f'{checkpoints_dir}/{epoch_num:03}.pt')
                # NOTE(review): amp.state_dict() is saved even when
                # use_apex is False, i.e. when amp.initialize was never
                # called -- verify this does not fail or store garbage.
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'amp': amp.state_dict()
                    }, f'{checkpoints_dir}/{epoch_num:03}.pt')
Example no. 6
0
def train(model_name, fold, run=None, resume_epoch=-1):
    """Train a joint classification + segmentation model for one fold.

    Like the classification-only trainer, but the model also predicts a
    segmentation map; samples with ground-truth masks contribute an extra
    BCE mask loss (weighted x10). Uses DataParallel, so model attributes
    are reached through model.module. Checkpoints are saved after each
    training phase, out-of-fold predictions after each validation phase.

    Args:
        model_name: key into the MODELS registry.
        fold: validation fold; all remaining folds are used for training.
        run: optional run suffix for checkpoint/log directory names.
        resume_epoch: checkpoint epoch to resume from, or -1 for fresh.
    """
    model_str = build_model_str(model_name, fold, run)

    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    model = model.cuda()

    # try:
    #     torchsummary.summary(model, (4, 512, 512))
    #     print('\n', model_name, '\n')
    # except:
    #     raise
    #     pass

    model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # NOTE(review): training reads '5fold-rev3.csv' while validation
    # reads '5fold.csv' -- confirm this split mismatch is intentional.
    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose([
            albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                            scale_limit=0.05,
                                            rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.7),
            albumentations.Flip(),
            albumentations.RandomRotate90(),
        ]),
        **{
            **model_info.dataset_args, "segmentation_oversample": 1
        })

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=None,
        **{
            **model_info.dataset_args, "segmentation_oversample": 1
        })

    data_loaders = {
        'train':
        DataLoader(dataset_train,
                   num_workers=8,
                   shuffle=True,
                   batch_size=model_info.batch_size),
        'val':
        DataLoader(dataset_valid,
                   shuffle=False,
                   num_workers=8,
                   batch_size=model_info.batch_size)
    }

    # Optional single-slice curriculum: for the first
    # model_info.single_slice_steps epochs the model trains on 1-slice
    # inputs at double batch size.
    dataset_train_1_slice = None
    if model_info.single_slice_steps > 0:
        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose([
                albumentations.ShiftScaleRotate(
                    shift_limit=16. / 256,
                    scale_limit=0.05,
                    rotate_limit=30,
                    interpolation=cv2.INTER_LINEAR,
                    border_mode=cv2.BORDER_REPLICATE,
                    p=0.75),
                albumentations.Flip(),
                albumentations.RandomRotate90()
            ]),
            **{
                **model_info.dataset_args, "num_slices": 1
            })

        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold.csv',
            folds=[fold],
            preprocess_func=None,
            **{
                **model_info.dataset_args, "num_slices": 1,
                "segmentation_oversample": 1
            })

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()
    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=0.2)

    print(
        f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}'
    )

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Per-class BCE weights; 'any' (last class) is weighted double.
    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        # Weighted multi-label BCE on logits; weights repeated per sample.
        return F.binary_cross_entropy_with_logits(
            y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    def criterium_mask(y_pred, y_true, have_segmentation):
        # Mask loss only over samples that have a ground-truth mask;
        # returns 0 (an int, not a tensor) when none do. The x10 factor
        # balances the mask loss against the classification loss.
        if not max(have_segmentation):
            return 0
        return F.binary_cross_entropy(y_pred[have_segmentation],
                                      y_true[have_segmentation]) * 10

    # criterium = nn.BCEWithLogitsLoss()

    # fit new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.module.freeze_encoder()
        # Prefer the 1-slice loader for the warm-up when available.
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 50000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        epoch_loss_mask = []
        initial_optimizer = radam.RAdam(model.parameters(), lr=1e-4)
        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                pred, segmentation = model(img)

                loss_cls = criterium(pred, labels)
                # max_pool2d(., 4) downsamples the mask to the model's
                # segmentation output resolution.
                loss_mask = criterium_mask(
                    segmentation, F.max_pool2d(segmentation_labels, 4),
                    have_segmentation)
                (loss_cls + loss_mask).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss_cls))
                if have_any_segmentation:
                    epoch_loss_mask.append(float(loss_mask))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                    +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}'
                )
    # NOTE(review): unlike the other trainer, this unfreeze sits outside
    # the pre-fit guard and so also runs on resume / non-pretrained
    # models -- confirm that is intentional (likely harmless).
    model.module.unfreeze_encoder()

    # 8 epochs total (hard-coded upper bound).
    for epoch_num in range(resume_epoch + 1, 8):
        # After epoch 3, stop oversampling mask-bearing samples.
        if epoch_num > 3 and dataset_train_1_slice is not None:
            dataset_train_1_slice.segmentation_oversample = 1

        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_loss_mask = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            # hasattr-style check for an optional per-epoch model hook.
            if 'on_epoch' in model.module.__dir__():
                model.module.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            # if epoch_num == model_info.single_slice_steps:
            #     print("train only conv slices/fn layers")
            #     model.module.freeze_encoder_full()
            #
            # if epoch_num == model_info.single_slice_steps+1:
            #     print("train all")
            #     model.module.unfreeze_encoder()
            #
            # if -1 < model_info.freeze_bn_step <= epoch_num:
            #     print("freeze bn")
            #     model.module.freeze_bn()

            data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                with torch.set_grad_enabled(phase == 'train'):
                    pred, segmentation = model(img)

                    loss_cls = criterium(pred, labels)
                    loss_mask = criterium_mask(
                        segmentation, F.max_pool2d(segmentation_labels, 4),
                        have_segmentation)

                    if phase == 'train':
                        # Gradient accumulation: scale the combined loss
                        # and step every accumulation_steps iterations.
                        ((loss_cls + loss_mask) /
                         model_info.accumulation_steps).backward()
                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), 16.0)
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss_cls))
                    if have_any_segmentation:
                        epoch_loss_mask.append(float(loss_mask))

                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(
                        torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                    +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}'
                )

            epoch_labels = np.row_stack(epoch_labels)
            epoch_predictions = np.row_stack(epoch_predictions)

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar(f'loss_mask_{phase}', np.mean(epoch_loss_mask),
                              epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'],
                              epoch_num)  # scheduler.get_lr()[0]
            try:
                log_metrics(logger=logger,
                            phase=phase,
                            epoch_num=epoch_num,
                            y=epoch_labels,
                            y_hat=epoch_predictions)
            except Exception:
                # Best-effort metric logging; a failure here must not
                # abort training.
                pass
            logger.flush()

            if phase == 'val':
                # NOTE(review): stepping the scheduler with an explicit
                # epoch is deprecated in newer torch -- confirm version.
                scheduler.step(epoch=epoch_num)
                # Out-of-fold predictions for this epoch.
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    }, f'{oof_dir}/{epoch_num:03}.pt')
            else:
                # print(f'{checkpoints_dir}/{epoch_num:03}.pt')
                # Unwrap DataParallel so the checkpoint loads without it.
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, f'{checkpoints_dir}/{epoch_num:03}.pt')