Esempio n. 1
0
@patches_grid(patch_size=PATCH_SIZE, stride=7 * PATCH_SIZE // 8)
@add_extract_dims(1)
def predict(image):
    """Predict per-voxel class probabilities for `image`.

    The decorators run inference patch-wise on an overlapping grid
    (stride is 7/8 of the patch size) and add/strip the extra batch dim.
    """
    probabilities = inference_step(image, architecture=model, activation=nn.Softmax(1))
    return probabilities


# metrics: per-class Dice and surface Dice for every foreground class,
# plus their aggregated (averaged-over-ids) counterparts for validation.
individual_metrics = {}
for c in range(1, dataset.n_classes):
    individual_metrics[f'dice_{c}'] = multiclass_to_bool(drop_spacing(dice_score), c)
for c in range(1, dataset.n_classes):
    individual_metrics[f'surface_dice_{c}'] = multiclass_to_bool(surface_dice, c)
val_metrics = convert_to_aggregated(individual_metrics)


# run experiment
logger = TBLogger(EXPERIMENT_PATH / FOLD / 'logs')
# populate(path, func, *args): runs func(*args) only if `path` does not exist yet;
# here save_json(CONFIG, EXPERIMENT_PATH / 'config.json') persists the config once.
commands.populate(EXPERIMENT_PATH / 'config.json', save_json, CONFIG, EXPERIMENT_PATH / 'config.json')
# The lambda returns a list only to sequence two calls (train, then save) in a
# single expression; training is skipped entirely if model.pth already exists.
commands.populate(EXPERIMENT_PATH / FOLD / 'model.pth', lambda : [
    train(train_step, batch_iter, n_epochs=CONFIG['n_epochs'], logger=logger,
          validate=lambda : compute_metrics(predict, train_dataset.load_image,
                                            lambda i: (train_dataset.load_gt(i), train_dataset.load_spacing(i)),
                                            val_ids, val_metrics),
          architecture=model, optimizer=optimizer, criterion=criterion, lr=CONFIG['lr']),
    save_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth')
])

# Reload the (possibly pre-existing) trained weights before evaluation.
load_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth')
for target_slice_spacing in CONFIG['target_slice_spacing']:
    test_dataset = ChangeSliceSpacing(dataset, new_slice_spacing=target_slice_spacing)
    commands.predict(
        ids=test_ids,
Esempio n. 2
0
def train(root,
          dataset,
          generator,
          discriminator,
          latent_dim,
          disc_iters,
          batch_size,
          batches_per_epoch,
          n_epochs,
          r1_weight,
          lr_discriminator,
          lr_generator,
          device='cuda'):
    """Train a GAN and save the final generator/discriminator states under `root`.

    Parameters
    ----------
    root: folder where logs, checkpoints and final weights are written.
    dataset: object exposing `ids` and an `image(id)` loader.
    generator, discriminator: torch modules; moved to `device` in place.
    latent_dim: dimensionality of the latent noise vectors.
    disc_iters: discriminator updates packed into each training step.
    batch_size: per-update batch size (real batch is batch_size * disc_iters).
    batches_per_epoch, n_epochs: training schedule.
    r1_weight: weight of the R1 gradient penalty, forwarded to `train_step`.
    lr_discriminator, lr_generator: Adam learning rates.
    device: torch device spec, defaults to 'cuda'.
    """
    root = Path(root)
    generator, discriminator = generator.to(device), discriminator.to(device)

    gen_opt = torch.optim.Adam(generator.parameters(), lr_generator)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr_discriminator)

    def _pack(images):
        # Reshape a flat list of `batch_size * disc_iters` images into
        # `disc_iters` sub-batches of shape (batch_size, 1, *image_shape),
        # so one training step can run several discriminator updates.
        stacked = np.array(images).reshape(
            disc_iters, len(images) // disc_iters, 1, *images[0].shape)
        return [tuple(stacked)]

    batch_iter = Infinite(
        sample(dataset.ids),
        dataset.image,
        combiner=_pack,
        buffer_size=10,
        batch_size=batch_size * disc_iters,
        batches_per_epoch=batches_per_epoch,
    )

    def val():
        """Render a fixed 12x12 grid of generated samples for TensorBoard."""
        # NOTE(review): eval mode is never reset to train mode here — confirm
        # that `train_step` switches the generator back, otherwise BatchNorm/
        # Dropout behave in eval mode for the rest of the epoch.
        generator.eval()
        # Fixed seed (state=0) so the image sequence is comparable across epochs.
        t = sample_on_sphere(latent_dim, 144, state=0).astype('float32')
        with torch.no_grad():
            # to_var presumably infers the target device from the module — verify.
            y = to_np(generator(to_var(t, device=generator)))[:, 0]

        return {'generated__image': bytescale(stitch(y, bytescale))}

    logger = TBLogger(root / 'logs')
    train_base(
        train_step,
        batch_iter,
        n_epochs,
        logger=logger,
        validate=val,
        checkpoints=Checkpoints(root / 'checkpoints',
                                [generator, discriminator, gen_opt, disc_opt]),
        generator=generator,
        discriminator=discriminator,
        gen_optimizer=gen_opt,
        disc_optimizer=disc_opt,
        latent_dim=latent_dim,
        r1_weight=r1_weight,
        time=TimeProfiler(logger.logger),
        tqdm=TQDM(False),
    )
    save_model_state(generator, root / 'generator.pth')
    save_model_state(discriminator, root / 'discriminator.pth')