@patches_grid(patch_size=PATCH_SIZE, stride=7 * PATCH_SIZE // 8) @add_extract_dims(1) def predict(image): return inference_step(image, architecture=model, activation=nn.Softmax(1)) # metrics individual_metrics = { **{f'dice_{c}': multiclass_to_bool(drop_spacing(dice_score), c) for c in range(1, dataset.n_classes)}, **{f'surface_dice_{c}': multiclass_to_bool(surface_dice, c) for c in range(1, dataset.n_classes)} } val_metrics = convert_to_aggregated(individual_metrics) # run experiment logger = TBLogger(EXPERIMENT_PATH / FOLD / 'logs') commands.populate(EXPERIMENT_PATH / 'config.json', save_json, CONFIG, EXPERIMENT_PATH / 'config.json') commands.populate(EXPERIMENT_PATH / FOLD / 'model.pth', lambda : [ train(train_step, batch_iter, n_epochs=CONFIG['n_epochs'], logger=logger, validate=lambda : compute_metrics(predict, train_dataset.load_image, lambda i: (train_dataset.load_gt(i), train_dataset.load_spacing(i)), val_ids, val_metrics), architecture=model, optimizer=optimizer, criterion=criterion, lr=CONFIG['lr']), save_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth') ]) load_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth') for target_slice_spacing in CONFIG['target_slice_spacing']: test_dataset = ChangeSliceSpacing(dataset, new_slice_spacing=target_slice_spacing) commands.predict( ids=test_ids,
def train(root, dataset, generator, discriminator, latent_dim, disc_iters, batch_size, batches_per_epoch, n_epochs,
          r1_weight, lr_discriminator, lr_generator, device='cuda'):
    """
    Train a generator/discriminator pair and save their final weights under ``root``.

    Parameters
    ----------
    root
        Experiment directory (str or path-like). TensorBoard logs are written to ``root / 'logs'``,
        checkpoints (both models and both optimizers) to ``root / 'checkpoints'``, and the final
        weights to ``root / 'generator.pth'`` and ``root / 'discriminator.pth'``.
    dataset
        Object exposing ``ids`` and an ``image`` loader that is applied to each sampled id.
    generator, discriminator
        Torch modules; both are moved to ``device`` before training.
    latent_dim
        Dimensionality of the latent vectors fed to the generator during validation
        (also forwarded through ``train_base``).
    disc_iters
        Number of chunks each loaded batch is split into; presumably one chunk per
        discriminator update -- TODO confirm against ``train_step``.
    batch_size
        Size of each chunk; ``batch_size * disc_iters`` images are loaded per training step.
    batches_per_epoch, n_epochs
        Training schedule for the batch iterator and ``train_base``.
    r1_weight
        Passed through ``train_base`` (presumably to ``train_step``); likely the weight of an
        R1 gradient penalty -- verify against ``train_step``.
    lr_discriminator, lr_generator
        Learning rates of the two Adam optimizers.
    device
        Device the models are moved to.
    """
    root = Path(root)

    generator, discriminator = generator.to(device), discriminator.to(device)
    gen_opt = torch.optim.Adam(generator.parameters(), lr_generator)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr_discriminator)

    def combine(images):
        # Split the ``batch_size * disc_iters`` loaded images into ``disc_iters`` chunks and insert
        # a singleton axis after the batch axis: each chunk is (batch_size, 1, *image_shape).
        chunks = np.array(images).reshape(disc_iters, len(images) // disc_iters, 1, *images[0].shape)
        return [tuple(chunks)]

    batch_iter = Infinite(
        sample(dataset.ids), dataset.image,
        combiner=combine,
        buffer_size=10, batch_size=batch_size * disc_iters, batches_per_epoch=batches_per_epoch,
    )

    def val():
        """Render generator samples for a fixed set of latents and return them for logging."""
        was_training = generator.training
        generator.eval()
        try:
            # ``state=0`` presumably seeds the sampler so the same 144 latents are rendered every
            # epoch -- confirm in ``sample_on_sphere``.
            latents = sample_on_sphere(latent_dim, 144, state=0).astype('float32')
            with torch.no_grad():
                # ``[:, 0]`` keeps index 0 of axis 1 (presumably the single output channel).
                images = to_np(generator(to_var(latents, device=generator)))[:, 0]
        finally:
            # Restore the previous mode; otherwise the generator would remain in eval mode
            # (frozen batch-norm statistics, disabled dropout) for all subsequent training steps.
            generator.train(was_training)

        return {'generated__image': bytescale(stitch(images, bytescale))}

    logger = TBLogger(root / 'logs')
    train_base(
        train_step, batch_iter, n_epochs, logger=logger,
        validate=val,
        checkpoints=Checkpoints(root / 'checkpoints', [generator, discriminator, gen_opt, disc_opt]),
        # the remaining keyword arguments are forwarded by ``train_base``
        generator=generator, discriminator=discriminator, gen_optimizer=gen_opt, disc_optimizer=disc_opt,
        latent_dim=latent_dim, r1_weight=r1_weight,
        time=TimeProfiler(logger.logger), tqdm=TQDM(False),
    )

    save_model_state(generator, root / 'generator.pth')
    save_model_state(discriminator, root / 'discriminator.pth')