Example #1
0
def generate_sample(pcnn_module, batch_size, rng_seed=0):
    """Restore trained model parameters and sample a batch of images.

    A model and an optimizer with placeholder parameters are built only so
    that `train.restore_checkpoint` has a structure to restore into. The
    second value it returns (named `ema` here, presumably exponential
    moving-average parameters) is then used as the model's parameters for
    sampling.

    Args:
      pcnn_module: model definition forwarded to `train.create_model`.
      batch_size: total number of images to sample. Must be a multiple of
        `jax.local_device_count()`, because the batch is split evenly
        across local devices.
      rng_seed: seed for the PRNG used for model init and for sampling.

    Returns:
      Array of shape `(batch_size, 32, 32, 3)` with the generated images.

    Raises:
      ValueError: if `batch_size` is not a multiple of the local device count.
    """
    image_shape = (32, 32, 3)

    rng = random.PRNGKey(rng_seed)
    rng, model_rng = random.split(rng)

    # Validate before any model construction or checkpoint I/O. An explicit
    # raise (instead of `assert`) keeps the check active under `python -O`.
    device_count = jax.local_device_count()
    if batch_size % device_count:
        raise ValueError(
            'Sampling batch size must be a multiple of the device count, got '
            'sample_batch_size={}, device_count={}.'.format(
                batch_size, device_count))

    # Create a model with dummy parameters and a dummy optimizer.
    example_images = jnp.zeros((1,) + image_shape)
    model = train.create_model(model_rng, example_images, pcnn_module)
    optimizer = train.create_optimizer(model, 0)

    # Load learned parameters.
    _, ema = train.restore_checkpoint(optimizer, model.params)
    model = model.replace(params=ema)

    # Start from an all-zeros batch, with a leading per-device axis.
    sample_prev = jnp.zeros(
        (device_count, batch_size // device_count) + image_shape)

    # One rng key per local device.
    sample_rng = random.split(rng, device_count)

    # Fixed-point iteration: feed each output back in until an iteration
    # leaves the batch unchanged. The same `sample_rng` is reused on every
    # step, so each step is a deterministic function of its input; that is
    # what makes "output == input" a meaningful stopping condition.
    sample = sample_iteration(sample_rng, model, sample_prev)
    while jnp.any(sample != sample_prev):
        sample_prev, sample = sample, sample_iteration(
            sample_rng, model, sample)
    return jnp.reshape(sample, (batch_size,) + image_shape)
Example #2
0
def generate_sample(config: ml_collections.ConfigDict, workdir: str):
    """Loads the latest model in `workdir` and samples a batch of images.

    Parameters are initialised with placeholder values and wrapped in an
    Adam optimizer only so that `train.restore_checkpoint` has a structure
    to restore into. The restored parameters are then used for sampling.

    Args:
      config: reads `sample_batch_size`, `sample_rng_seed` and
        `learning_rate`. It is also forwarded to `train.model` and
        `sample_iteration`.
      workdir: directory passed to `train.restore_checkpoint`.

    Returns:
      Array of shape `(config.sample_batch_size, 32, 32, 3)` with the
      generated images.

    Raises:
      ValueError: if `config.sample_batch_size` is not a multiple of
        `jax.local_device_count()`, because the batch is split evenly
        across local devices.
    """
    image_shape = (32, 32, 3)
    batch_size = config.sample_batch_size

    # Validate before any model construction or checkpoint I/O. An explicit
    # raise (instead of `assert`) keeps the check active under `python -O`.
    device_count = jax.local_device_count()
    if batch_size % device_count:
        raise ValueError(
            'Sampling batch size must be a multiple of the device count, got '
            'sample_batch_size={}, device_count={}.'.format(
                batch_size, device_count))

    rng = random.PRNGKey(config.sample_rng_seed)
    rng, model_rng = random.split(rng)
    rng, dropout_rng = random.split(rng)

    # Create a model with dummy parameters and a dummy optimizer.
    init_batch = jnp.zeros((1,) + image_shape)
    params = train.model(config).init(
        {
            'params': model_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    # NOTE(review): these Adam settings presumably need to match the
    # training optimizer so the checkpoint's structure lines up — confirm.
    optimizer_def = optim.Adam(learning_rate=config.learning_rate,
                               beta1=0.95,
                               beta2=0.9995)
    optimizer = optimizer_def.create(params)

    _, params = train.restore_checkpoint(workdir, optimizer, params)

    # Start from an all-zeros batch, with a leading per-device axis.
    sample_prev = jnp.zeros(
        (device_count, batch_size // device_count) + image_shape)

    # One rng key per local device.
    sample_rng = random.split(rng, device_count)

    # Fixed-point iteration: feed each output back in until an iteration
    # leaves the batch unchanged. The same `sample_rng` is reused on every
    # step, so each step is a deterministic function of its input; that is
    # what makes "output == input" a meaningful stopping condition.
    sample = sample_iteration(config, sample_rng, params, sample_prev)
    while jnp.any(sample != sample_prev):
        sample_prev, sample = sample, sample_iteration(
            config, sample_rng, params, sample)
    return jnp.reshape(sample, (batch_size,) + image_shape)
Example #3
0
    if generate_txt:
        answer_file.close()

    return np.mean(aucs), np.mean(mrrs), np.mean(ndcg5s), np.mean(ndcg10s)


if __name__ == '__main__':
    import sys

    # Imported here rather than at module top to avoid a circular import
    # with `train`.
    from train import parse_arguments, get_model, restore_checkpoint

    parser = argparse.ArgumentParser(description='Eval params')
    config = parse_arguments(parser)

    # Build the model, then load its weights from a checkpoint. Without a
    # checkpoint there is nothing to evaluate, so abort.
    model = get_model(config)
    model, is_sucessfull = restore_checkpoint(config, model, is_train=False)

    if not is_sucessfull:
        # sys.exit(<str>) prints the message to stderr and exits with status
        # 1, so shell scripts/CI see the failure. The previous bare `exit()`
        # exited with status 0 and relies on the `site` module being loaded.
        sys.exit('No checkpoint file found!')

    # Output folder for prediction files (e.g. prediction.txt), created if
    # it does not exist yet.
    prediction_folder = f'{config.val_dir}/{config.model_name}'
    Path(prediction_folder).mkdir(parents=True, exist_ok=True)
    if config.model_name.startswith('DM'):
        auc, mrr, ndcg5, ndcg10 = evaluate_dm(config,
                                              model,
                                              config.dev_dir,
                                              config.train_dir,
                                              generate_txt=True,
                                              txt_path=prediction_folder +
                                              '/prediction.txt',