def generate_sample(pcnn_module, batch_size, rng_seed=0):
    """Loads the latest checkpoint and samples a batch of 32x32x3 images.

    Args:
      pcnn_module: PixelCNN++ module used to build the model via
        `train.create_model`.
      batch_size: total number of images to sample; must be a multiple of
        the local device count so the batch can be sharded across devices.
      rng_seed: integer seed for the JAX PRNG (default 0 for
        reproducible samples).

    Returns:
      A jnp array of sampled images with shape (batch_size, 32, 32, 3).

    Raises:
      ValueError: if `batch_size` is not a multiple of the local device
        count.
    """
    rng = random.PRNGKey(rng_seed)
    rng, model_rng = random.split(rng)

    # Create a model with dummy parameters and a dummy optimizer.
    example_images = jnp.zeros((1, 32, 32, 3))
    model = train.create_model(model_rng, example_images, pcnn_module)
    optimizer = train.create_optimizer(model, 0)

    # Load learned parameters; sample from the EMA weights, which the
    # checkpoint stores alongside the optimizer state.
    _, ema = train.restore_checkpoint(optimizer, model.params)
    model = model.replace(params=ema)

    # Initialize a per-device batch of images.
    device_count = jax.local_device_count()
    if batch_size % device_count:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, silently allowing a bad shape through.
        raise ValueError(
            'Sampling batch size must be a multiple of the device count, got '
            'sample_batch_size={}, device_count={}.'.format(
                batch_size, device_count))
    sample_prev = jnp.zeros(
        (device_count, batch_size // device_count, 32, 32, 3))
    # ... and a batch of rng keys, one per device.
    sample_rng = random.split(rng, device_count)

    # Generate the sample by fixed-point iteration: re-apply the sampling
    # step (with the SAME rng keys) until the images stop changing.
    sample = sample_iteration(sample_rng, model, sample_prev)
    while jnp.any(sample != sample_prev):
        sample_prev, sample = sample, sample_iteration(sample_rng, model,
                                                      sample)
    return jnp.reshape(sample, (batch_size, 32, 32, 3))
def generate_sample(config: ml_collections.ConfigDict, workdir: str):
    """Loads the latest model in `workdir` and samples a batch of images.

    Args:
      config: experiment configuration; reads `sample_batch_size`,
        `sample_rng_seed`, and `learning_rate`, and is forwarded to
        `train.model` and `sample_iteration`.
      workdir: directory containing the checkpoint to restore.

    Returns:
      A jnp array of sampled images with shape
      (config.sample_batch_size, 32, 32, 3).

    Raises:
      ValueError: if `config.sample_batch_size` is not a multiple of the
        local device count.
    """
    batch_size = config.sample_batch_size
    rng = random.PRNGKey(config.sample_rng_seed)
    rng, model_rng = random.split(rng)
    rng, dropout_rng = random.split(rng)

    # Create a model with dummy parameters and a dummy optimizer.
    init_batch = jnp.zeros((1, 32, 32, 3))
    params = train.model(config).init(
        {
            'params': model_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    optimizer_def = optim.Adam(learning_rate=config.learning_rate,
                               beta1=0.95,
                               beta2=0.9995)
    optimizer = optimizer_def.create(params)
    # Restore learned parameters from the latest checkpoint in `workdir`.
    _, params = train.restore_checkpoint(workdir, optimizer, params)

    # Initialize a per-device batch of images.
    device_count = jax.local_device_count()
    if batch_size % device_count:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, silently allowing a bad shape through.
        raise ValueError(
            'Sampling batch size must be a multiple of the device count, got '
            'sample_batch_size={}, device_count={}.'.format(
                batch_size, device_count))
    sample_prev = jnp.zeros(
        (device_count, batch_size // device_count, 32, 32, 3))
    # ... and a batch of rng keys, one per device.
    sample_rng = random.split(rng, device_count)

    # Generate the sample by fixed-point iteration: re-apply the sampling
    # step (with the SAME rng keys) until the images stop changing.
    sample = sample_iteration(config, sample_rng, params, sample_prev)
    while jnp.any(sample != sample_prev):
        sample_prev, sample = sample, sample_iteration(config, sample_rng,
                                                      params, sample)
    return jnp.reshape(sample, (batch_size, 32, 32, 3))
if generate_txt: answer_file.close() return np.mean(aucs), np.mean(mrrs), np.mean(ndcg5s), np.mean(ndcg10s) if __name__ == '__main__': # avoid circular import from train import parse_arguments, get_model, restore_checkpoint parser = argparse.ArgumentParser(description='Eval params') config = parse_arguments(parser) model = get_model(config) model, is_sucessfull = restore_checkpoint(config, model, is_train=False) if not is_sucessfull: print('No checkpoint file found!') exit() prediction_folder = f'{config.val_dir}/{config.model_name}' Path(prediction_folder).mkdir(parents=True, exist_ok=True) if config.model_name.startswith('DM'): auc, mrr, ndcg5, ndcg10 = evaluate_dm(config, model, config.dev_dir, config.train_dir, generate_txt=True, txt_path=prediction_folder + '/prediction.txt',