Code Example #1
def train_autoencoder(model,
                      optimizer,
                      criterion,
                      batch_iter,
                      n_epochs,
                      train_dataset,
                      path,
                      batch_size=200,
                      length=500):
    '''
    :param model: PyTorch neural network model
    :param optimizer: Torch optimization strategy: SGD, Adam, AdaDelta, ...
    :param criterion: Loss function
    :param batch_iter: Batch iterator function
    :param n_epochs: Number of epochs to train the model
    :param train_dataset: Dataset object used to train the model
    :param path: Experiment output path
    :param batch_size: Size of the batches
    :param length: Length of the crop used when loading the training data
    :return: None
    '''
    # initial setup
    path = Path(path)
    logger = NamedTBLogger(path / 'logs', ['loss'])
    model.eval()

    for step in tqdm(range(n_epochs)):

        model.train()
        losses = []
        for inputs in batch_iter(train_dataset, batch_size, length):

            *inputs, target = sequence_to_var(*inputs[:2],
                                              cuda=is_on_cuda(model))

            logits = model(*inputs)

            if isinstance(criterion, nn.BCELoss):
                # BCELoss expects floating-point targets
                target = target.float()

            total = criterion(logits, target)

            optimizer.zero_grad()
            total.backward()
            optimizer.step()

            losses.append(sequence_to_np(total))

        logger.train(losses, step)
        print(f'Mean loss: {np.mean(losses):.4f}')

    save_model_state(model, path / 'model.pth')
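
For reference, here is a minimal sketch of how `train_autoencoder` might be invoked. The model, batch iterator, and dataset below are hypothetical placeholders, not objects from the original project:

import torch.nn as nn
from torch.optim import Adam

# hypothetical 1-D convolutional autoencoder; any nn.Module would do
model = nn.Sequential(
    nn.Conv1d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv1d(16, 1, kernel_size=3, padding=1),
)

train_autoencoder(
    model=model,
    optimizer=Adam(model.parameters(), lr=1e-3),
    criterion=nn.MSELoss(),
    batch_iter=my_batch_iter,   # placeholder: yields (input, target) batches
    n_epochs=50,
    train_dataset=my_dataset,   # placeholder dataset object
    path='experiments/autoencoder',
)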
Code Example #2
File: unet_3d.py Project: migonch/two_and_half_d
# (the assignment's opening line is restored here; the snippet begins mid-expression in the source)
individual_metrics = {
    **{f'dice_{c}': multiclass_to_bool(drop_spacing(dice_score), c) for c in range(1, dataset.n_classes)},
    **{f'surface_dice_{c}': multiclass_to_bool(surface_dice, c) for c in range(1, dataset.n_classes)}
}
val_metrics = convert_to_aggregated(individual_metrics)


# run experiment
logger = TBLogger(EXPERIMENT_PATH / FOLD / 'logs')
# `populate` runs the given command only if the target file does not already exist
commands.populate(EXPERIMENT_PATH / 'config.json', save_json, CONFIG, EXPERIMENT_PATH / 'config.json')
commands.populate(EXPERIMENT_PATH / FOLD / 'model.pth', lambda: [
    train(train_step, batch_iter, n_epochs=CONFIG['n_epochs'], logger=logger,
          validate=lambda: compute_metrics(predict, train_dataset.load_image,
                                           lambda i: (train_dataset.load_gt(i), train_dataset.load_spacing(i)),
                                           val_ids, val_metrics),
          architecture=model, optimizer=optimizer, criterion=criterion, lr=CONFIG['lr']),
    save_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth')
])
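
Note the `lambda: [...]` idiom above: wrapping the `train(...)` and `save_model_state(...)` calls in a list literal lets a single argument-less lambda run both, in order, when `populate` invokes it.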

load_model_state(model, EXPERIMENT_PATH / FOLD / 'model.pth')
for target_slice_spacing in CONFIG['target_slice_spacing']:
    test_dataset = ChangeSliceSpacing(dataset, new_slice_spacing=target_slice_spacing)
    commands.predict(
        ids=test_ids,
        output_path=EXPERIMENT_PATH / FOLD / f"predictions_{CONFIG['source_slice_spacing']}_to_{target_slice_spacing}",
        load_x=test_dataset.load_image,
        predict_fn=predict
    )
    commands.evaluate_individual_metrics(
        load_y_true=lambda i: (test_dataset.load_gt(i), test_dataset.load_spacing(i)),
        metrics=individual_metrics,
        predictions_path=EXPERIMENT_PATH / FOLD / f"predictions_{CONFIG['source_slice_spacing']}_to_{target_slice_spacing}",
    )  # call truncated here in the source listing; remaining arguments omitted
Code Example #3
# see https://deep-pipe.readthedocs.io for more details about the batch iterators we use

batch_iter = Infinite(
    # get a random pair of paths
    sample(paths),
    # load the image-contour pair
    unpack_args(load_pair),
    # get a random slice
    unpack_args(get_random_slice),
    # simple augmentation
    unpack_args(random_flip),
    batch_size=batch_size,
    batches_per_epoch=samples_per_train // (batch_size * n_epochs),
    combiner=combiner)

model = to_device(Network(), args.device)
optimizer = Adam(model.parameters(), lr=lr)

# Here we use a general training function with a custom `train_step`.
# See the tutorial for more details: https://deep-pipe.readthedocs.io/en/latest/tutorials/training.html

train(
    train_step,
    batch_iter,
    n_epochs=n_epochs,
    logger=ConsoleLogger(),
    # additional arguments to `train_step`
    model=model,
    optimizer=optimizer)
save_model_state(model, args.output)
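
For context, a `train_step` compatible with the general `train` function receives one unpacked batch plus the extra keyword arguments passed to `train` (here `model` and `optimizer`), performs a single optimization step, and returns the loss. A minimal sketch under these assumptions; this is illustrative, not the project's actual implementation:

import torch
import torch.nn.functional as F

def train_step(x, y, *, model, optimizer):
    # one optimization step on a single (input, target) batch of numpy arrays
    model.train()
    device = next(model.parameters()).device
    x = torch.as_tensor(x, device=device)
    y = torch.as_tensor(y, device=device)  # assumes float targets for a BCE-style loss

    loss = F.binary_cross_entropy_with_logits(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()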
Code Example #4
def train(model,
          optimizer,
          criterion,
          batch_iter,
          n_epochs,
          train_dataset,
          val_data,
          val_labels,
          path,
          batch_size=200,
          length=500):
    '''
    :param model: PyTorch neural network model
    :param optimizer: Torch optimization strategy: SGD, Adam, AdaDelta, ...
    :param criterion: Loss function
    :param batch_iter: Batch iterator function
    :param n_epochs: Number of epochs to train the model
    :param train_dataset: Dataset object used to train the model
    :param val_data: Loaded numpy data used to evaluate the model
    :param val_labels: Loaded numpy targets used to evaluate the model
    :param path: Experiment output path
    :param batch_size: Size of the batches
    :param length: Length of the crop used when loading the training data
    :return: None
    '''
    # initial setup
    path = Path(path)
    logger = NamedTBLogger(path / 'logs', ['loss'])
    model.eval()

    best_score = None

    for step in tqdm(range(n_epochs)):

        model.train()
        losses = []
        for inputs in batch_iter(train_dataset, batch_size, length):
            *inputs, target = sequence_to_var(*inputs,
                                              cuda=is_on_cuda(model))

            logits = model(*inputs)

            if isinstance(criterion, nn.BCELoss):
                target = target.float()

            total = criterion(logits, target)

            optimizer.zero_grad()
            total.backward()
            optimizer.step()

            losses.append(sequence_to_np(total))

        logger.train(losses, step)

        # validation
        model.eval()

        # metrics
        score = evaluate(model, val_data, val_labels)
        dump_json(score, path / 'val_accuracy.json')
        print(f'Val score {score}')
        logger.metrics({'accuracy': score}, step)

        # best model
        if best_score is None or best_score < score:
            best_score = score
            save_model_state(model, path / 'best_model.pth')

    save_model_state(model, path / 'model.pth')
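
This example assumes an `evaluate(model, val_data, val_labels)` helper that is not shown. A plausible sketch, assuming numpy validation arrays and a model that returns class logits (hypothetical, for illustration only):

import numpy as np
import torch

def evaluate(model, val_data, val_labels):
    # accuracy of argmax predictions on held-out numpy data
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.as_tensor(val_data, device=device))
    predictions = logits.argmax(dim=1).cpu().numpy()
    return float(np.mean(predictions == val_labels))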
Code Example #5
def train(root,
          dataset,
          generator,
          discriminator,
          latent_dim,
          disc_iters,
          batch_size,
          batches_per_epoch,
          n_epochs,
          r1_weight,
          lr_discriminator,
          lr_generator,
          device='cuda'):
    root = Path(root)
    generator, discriminator = generator.to(device), discriminator.to(device)

    gen_opt = torch.optim.Adam(generator.parameters(), lr_generator)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr_discriminator)

    batch_iter = Infinite(
        sample(dataset.ids),
        dataset.image,
        combiner=lambda images: [
            tuple(np.array(images).reshape(
                disc_iters, len(images) // disc_iters, 1, *images[0].shape))
        ],
        buffer_size=10,
        batch_size=batch_size * disc_iters,
        batches_per_epoch=batches_per_epoch,
    )

    def val():
        generator.eval()
        t = sample_on_sphere(latent_dim, 144, state=0).astype('float32')
        with torch.no_grad():
            y = to_np(generator(to_var(t, device=generator)))[:, 0]

        return {'generated__image': bytescale(stitch(y, bytescale))}

    logger = TBLogger(root / 'logs')
    train_base(
        train_step,
        batch_iter,
        n_epochs,
        logger=logger,
        validate=val,
        checkpoints=Checkpoints(root / 'checkpoints',
                                [generator, discriminator, gen_opt, disc_opt]),
        # lr=lr,
        generator=generator,
        discriminator=discriminator,
        gen_optimizer=gen_opt,
        disc_optimizer=disc_opt,
        latent_dim=latent_dim,
        r1_weight=r1_weight,
        time=TimeProfiler(logger.logger),
        tqdm=TQDM(False),
    )
    save_model_state(generator, root / 'generator.pth')
    save_model_state(discriminator, root / 'discriminator.pth')
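
Once saved, the weights can be restored for sampling with `load_model_state` (used the same way in code example #2). A minimal sketch; Gaussian latents are used here for simplicity, whereas the experiment itself draws them with `sample_on_sphere`:

import torch

load_model_state(generator, root / 'generator.pth')
generator.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim, device=next(generator.parameters()).device)
    samples = generator(z)  # batch of generated images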