def val_dataloader(self):
    """Build the validation DataLoader: fixed order, no augmentation."""
    loader = get_data_loader(
        dataset_name=self.config.dataset.name,
        modalities=self.config.dataset.modalities,
        root_dir_paths=self.config.dataset.root_dir_paths,
        augmentation_type='none',
        use_shuffle=False,
        batch_size=self.config.dataset.batch_size,
        num_workers=self.config.dataset.num_workers,
    )
    return loader
 def train_dataloader(self):
     """Build the training DataLoader; augmentation/shuffle come from config."""
     cfg = self.config.dataset
     return get_data_loader(
         dataset_name=cfg.name,
         modalities=cfg.modalities,
         root_dir_paths=cfg.root_dir_paths,
         augmentation_type=cfg.augmentation_type,
         use_shuffle=cfg.use_shuffle,
         batch_size=cfg.batch_size,
         num_workers=cfg.num_workers,
     )
def main(config, needs_save):
    """Train a VAE (encoder ``E`` / decoder ``D``) with pytorch-ignite.

    Wires up data loading, models, optimizers and loss closures, then
    registers ignite handlers for logging, sample-image dumps and
    checkpointing, and finally runs the trainer.

    Args:
        config: nested namedtuple-like configuration object (``_asdict`` is
            used on it and its children when serializing below).
        needs_save: if True, write ``config.json``, ``logs.tsv``, sample
            images and model checkpoints under the run's output directory.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices
    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    # Resolve the patient split by name; only the two module-level
    # constants are supported.
    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()

    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)

    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    # Optional warm start. NOTE(review): checkpoints must have been saved
    # with the same DataParallel wrapping as applied above — confirm.
    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))

    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    # Adam with betas (0.9, 0.9999), learning rates from the config.
    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])

    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    alpha = config.training.alpha  # NOTE(review): read but never used below
    beta = config.training.beta  # weight on the reconstruction term
    margin = config.training.margin  # NOTE(review): read but never used below

    batch_size = config.dataset.batch_size
    # Fixed noise so the generated samples saved each epoch are comparable.
    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        # Reconstruction loss selected by config, scaled by `beta` and
        # normalized by the configured batch size (not the actual batch).
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')

        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)

        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.mse_loss(recon, target, reduction='sum')

        else:
            raise NotImplementedError

        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        # KL divergence of N(mu, exp(log_var)) from N(0, I), per batch item.
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        # One joint optimization step of E and D on a single batch.
        E.train()
        D.train()

        image = norm(batch['image'])

        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon

        l_vae_total.backward()

        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']

    # Exponential running averages of the per-iteration losses; `partial`
    # binds `metric` so the lambda does not close over the loop variable.
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        # Flatten the two-level config into a plain dict and persist it.
        config_to_save = defaultdict(dict)

        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v

        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir

        print('Training starts by the following configuration: ',
              config_to_save)

        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        # Log metrics every `log_iter_interval` iterations (including the
        # very first one, hence the `- 1`).
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))

            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        # Append one TSV row per epoch; the header is written only when the
        # file is empty (f.tell() == 0 right after opening in append mode).
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        # Dump a grid of [input | reconstruction | sample from fixed_z].
        # NOTE(review): `image` here is the raw CPU batch (no .cuda()/.float()
        # as in `update`); presumably DataParallel scatters it — confirm.
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])

                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]

                save_path = os.path.join(
                    output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'E': E,
                                      'D': D
                                  })

    # Measure average time per batch (resumed/paused around each iteration).
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
def main(config, needs_save, i):
    """Train a ResUNet segmentation model on cross-validation fold ``i``.

    NOTE(review): this redefines ``main`` from earlier in this file (the
    file looks like a concatenation of separate example scripts).

    Args:
        config: nested namedtuple-like configuration object.
        needs_save: if True, persist config, logs, sample images and
            checkpoints under the fold's output directory.
        i: cross-validation fold index, passed to ``get_cv_splits`` and
            ``get_output_dir_path``.
    """

    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    # Train/val splits come from the same root directory.
    assert config.train_dataset.root_dir_path == config.val_dataset.root_dir_path
    # train_patient_ids, val_patient_ids = divide_patients(config.train_dataset.root_dir_path)
    train_patient_ids, val_patient_ids = get_cv_splits(
        config.train_dataset.root_dir_path, i)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    train_data_loader = get_data_loader(
        mode='train',
        dataset_name=config.train_dataset.dataset_name,
        root_dir_path=config.train_dataset.root_dir_path,
        patient_ids=train_patient_ids,
        batch_size=config.train_dataset.batch_size,
        num_workers=config.train_dataset.num_workers,
        volume_size=config.train_dataset.volume_size,
    )

    val_data_loader = get_data_loader(
        mode='val',
        dataset_name=config.val_dataset.dataset_name,
        root_dir_path=config.val_dataset.root_dir_path,
        patient_ids=val_patient_ids,
        batch_size=config.val_dataset.batch_size,
        num_workers=config.val_dataset.num_workers,
        volume_size=config.val_dataset.volume_size,
    )

    model = ResUNet(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        filters=config.model.filters,
    )

    print(model)

    if config.run.use_cuda:
        model.cuda()
        model = nn.DataParallel(model)

    # Either resume from a checkpoint or initialize weights from scratch.
    if config.model.saved_model:
        print('Loading saved model: {}'.format(config.model.saved_model))
        model.load_state_dict(torch.load(config.model.saved_model))
    else:
        print('Initializing weights.')
        init_weights(model, init_type=config.model.init_type)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=config.optimizer.lr,
                           betas=config.optimizer.betas,
                           weight_decay=config.optimizer.weight_decay)

    # The total training loss is the unweighted sum of these three terms.
    dice_loss = SoftDiceLoss()

    focal_loss = FocalLoss(
        gamma=config.focal_loss.gamma,
        alpha=config.focal_loss.alpha,
    )

    active_contour_loss = ActiveContourLoss(
        weight=config.active_contour_loss.weight, )

    dice_coeff = DiceCoefficient(
        n_classes=config.metric.n_classes,
        index_to_class_name=index_to_class_name,
    )

    one_hot_encoder = OneHotEncoder(
        n_classes=config.metric.n_classes, ).forward

    def train(engine, batch):
        # One optimization step; the LR is decayed per epoch before the step.
        adjust_learning_rate(optimizer,
                             engine.state.epoch,
                             initial_lr=config.optimizer.lr,
                             n_epochs=config.run.n_epochs,
                             gamma=config.optimizer.gamma)

        model.train()

        image = batch['image']
        label = batch['label']

        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()

        else:
            image = image.float()
            label = label.long()

        optimizer.zero_grad()

        output = model(image)
        # Drop channel 0 of the one-hot target (background class).
        target = one_hot_encoder(label)[:, 1:, ...]

        l_dice = dice_loss(output, target)
        l_focal = focal_loss(output, target)
        l_active_contour = active_contour_loss(output, target)

        l_total = l_dice + l_focal + l_active_contour
        l_total.backward()

        optimizer.step()

        m_dice = dice_coeff.update(output.detach(), label)

        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }

        measures.update(m_dice)

        if config.run.use_cuda:
            torch.cuda.synchronize()

        return measures

    def evaluate(engine, batch):
        # Validation step: same losses/metrics as `train`, but no backprop.
        model.eval()

        image = batch['image']
        label = batch['label']

        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()

        else:
            image = image.float()
            label = label.long()

        with torch.no_grad():
            output = model(image)
            target = one_hot_encoder(label)[:, 1:, ...]

            l_dice = dice_loss(output, target)
            l_focal = focal_loss(output, target)
            l_active_contour = active_contour_loss(output, target)

            m_dice = dice_coeff.update(output.detach(), label)

        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }

        measures.update(m_dice)

        if config.run.use_cuda:
            torch.cuda.synchronize()

        return measures

    output_dir_path = get_output_dir_path(config, i)
    trainer = Engine(train)
    evaluator = Engine(evaluate)
    timer = Timer(average=True)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir_path,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            # n_epochs + 1 slots: effectively keep every saved checkpoint.
            n_saved=config.run.n_epochs + 1,
            create_dir=True,
        )

    # Per-class dice metrics are reported under their class names.
    monitoring_metrics = ['SoftDiceLoss', 'FocalLoss', 'ActiveContourLoss']
    monitoring_metrics += class_name_to_index.keys()

    # `partial` binds `metric` so the lambda does not close over the loop var.
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    trainer, metric)

    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    evaluator, metric)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def call_save_config(engine):
        if needs_save:
            return save_config(engine, config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_save_logs(engine):
        if needs_save:
            return save_logs('train', engine, config, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_print_times(engine):
        return print_times(engine, config, pbar, timer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        # Run one pass over the validation set after every training epoch.
        evaluator.run(val_data_loader, 1)

        if needs_save:
            save_logs('val', evaluator, config, output_dir_path)
            save_images(evaluator, trainer.state.epoch)

    def save_images(evaluator, epoch):
        # Render the evaluator's last batch: middle axial slice of the
        # input, the ground-truth label and the predicted label map.
        batch = evaluator.state.batch
        image = batch['image']
        label = batch['label']

        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()

        with torch.no_grad():
            pred = model(image)

        output = torch.ones_like(label)

        # Voxels where all three foreground channels are below 0.5 are
        # forced to background (class 0) after the argmax below.
        mask_0 = pred[:, 0, ...] < 0.5
        mask_1 = pred[:, 1, ...] < 0.5
        mask_2 = pred[:, 2, ...] < 0.5
        mask = mask_0 * mask_1 * mask_2

        # argmax is over foreground channels only, hence the +1 offset
        # provided by initializing `output` to ones.
        pred = pred.argmax(1)
        output += pred

        output[mask] = 0

        image = image.detach().cpu().float()
        label = label.detach().cpu().unsqueeze(1).float()
        output = output.detach().cpu().unsqueeze(1).float()

        # Take the middle slice along the last (z) axis for visualization.
        z_middle = image.shape[-1] // 2
        image = image[:, 0, ..., z_middle]
        label = label[:, 0, ..., z_middle]
        output = output[:, 0, ..., z_middle]

        if config.save.image_vmax is not None:
            vmax = config.save.image_vmax
        else:
            vmax = image.max()

        if config.save.image_vmin is not None:
            vmin = config.save.image_vmin
        else:
            vmin = image.min()

        # Window the intensities to [vmin, vmax] and rescale to [0, 255].
        image = np.clip(image, vmin, vmax)
        image -= vmin
        image /= (vmax - vmin)
        image *= 255.0

        save_path = os.path.join(output_dir_path,
                                 'result_{}.png'.format(epoch))
        save_images_via_plt(image, label, output, config.save.n_save_images,
                            config, save_path)

    if needs_save:
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'model': model,
                                      'optim': optimizer
                                  })

    # Measure average time per batch (resumed/paused around each iteration).
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
# ---- Example #5 ---- (scraper artifact: separator between pasted snippets,
# kept here as a comment so the file stays syntactically inert at this point)
def inference(config):
    """Run 5-model ensemble segmentation inference and save NIfTI outputs.

    For each test volume (batch size must be 1), the five trained models'
    soft predictions are averaged, turned into a label map (argmax over
    foreground channels, offset by 1, with low-confidence voxels forced to
    background 0), and written as ``<patient_id>_output.nii.gz`` with the
    affine taken from the source ``.nii`` file.

    Args:
        config: nested namedtuple-like configuration object.
    """

    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    test_patient_ids = os.listdir(config.test_dataset.root_dir_path)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    test_data_loader = get_data_loader(
        mode='test',
        dataset_name=config.test_dataset.dataset_name,
        root_dir_path=config.test_dataset.root_dir_path,
        patient_ids=test_patient_ids,
        batch_size=config.test_dataset.batch_size,
        num_workers=config.test_dataset.num_workers,
        volume_size=config.test_dataset.volume_size,
    )

    # Load the five cross-validation models once, all in eval mode.
    models = [get_trained_model(cfg) for cfg in (config.model_1,
                                                 config.model_2,
                                                 config.model_3,
                                                 config.model_4,
                                                 config.model_5)]
    for m in models:
        m.eval()

    for batch in tqdm(test_data_loader):
        image = batch['image'].cuda().float()
        assert image.size(0) == 1  # inference is strictly single-volume

        patient_id = batch['patient_id'][0]
        nii_path = batch['nii_path'][0]

        # Pad the last (z) axis by (2, 3); the padding is stripped again
        # before saving so the output matches the source geometry.
        image = F.pad(image, (2, 3, 0, 0, 0, 0, 0, 0, 0, 0), 'constant', 0)
        output = torch.ones((1, image.shape[2], image.shape[3], image.shape[4]))

        with torch.no_grad():
            # Ensemble: average the soft predictions of the five models.
            pred = sum(m(image) for m in models) / 5.0

        # Voxels where every foreground channel is below 0.5 become
        # background. Use logical `&` on the boolean masks instead of
        # multiplying bool tensors.
        mask = ((pred[:, 0, ...] < 0.5)
                & (pred[:, 1, ...] < 0.5)
                & (pred[:, 2, ...] < 0.5))

        # argmax over foreground channels; `output` starts at ones so the
        # label ids are offset by 1 relative to the channel index.
        pred = pred.argmax(1).cpu()
        output += pred

        output[mask] = 0

        # Undo the z-padding applied above.
        output = output[..., 2:-3]

        save_dir_path = os.path.join(config.save.save_root_dir, patient_id)
        os.makedirs(save_dir_path, exist_ok=True)

        output = output.cpu().numpy()[0, ...].astype(np.int16)

        # Reuse the source volume's affine so the output overlays correctly.
        nii_image = nib.load(nii_path)
        nii_output = nib.Nifti1Image(output, affine=nii_image.affine)

        # Fixed: the original wrapped this in a redundant second
        # os.path.join call.
        nib.save(nii_output,
                 os.path.join(save_dir_path, patient_id + '_output.nii.gz'))