def generate_prediction(*, p_pred_step, params, tokenized_prompts, eos_id,
                        inference_rng, decode_tokens, max_predict_length: int):
    """Generate text from the prompt."""
    n_devices = jax.local_device_count()

    logging.info("Generating text.")
    predictions = []
    # Use batch of prompts provided by user.
    for pred_batch in jnp.array_split(
            tokenized_prompts, int(np.ceil(len(tokenized_prompts) / n_devices))):
        cur_pred_batch_size = pred_batch.shape[0]
        if cur_pred_batch_size % n_devices:
            padded_size = int(
                np.ceil(cur_pred_batch_size / n_devices) * n_devices)
            pred_batch = jax.tree_map(
                lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                pred_batch)
        pred_batch = common_utils.shard(pred_batch)
        inference_rng, sub_rng = random.split(inference_rng)
        inference_rngs = random.split(sub_rng, n_devices)

        predicted = p_pred_step(pred_batch, params, inference_rngs, eos_id,
                                max_predict_length)
        predicted = tohost(predicted)
        # Iterate through non-padding examples of batch.
        for s in predicted[:cur_pred_batch_size]:
            prediction = decode_tokens(s)
            logging.info("Sample: %s", str(prediction))
            predictions.append(prediction)

    # Save generated texts for tensorboard.
    exemplars = ""
    for prediction in predictions:
        exemplars += f"{prediction}\n\n"
    return exemplars
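# --- Hedged usage sketch (not from the original source) ---
# Shows how the jnp.array_split call above partitions prompts into roughly
# device-sized batches, and how the padding branch rounds a ragged batch up
# to a multiple of the device count. `n_devices` is hard-coded here instead
# of coming from jax.local_device_count(), and last-row tiling stands in for
# the project's pad_examples helper.
import numpy as np
import jax.numpy as jnp

n_devices = 4
tokenized_prompts = jnp.ones((10, 16), dtype=jnp.int32)  # 10 prompts of length 16
for batch in jnp.array_split(
        tokenized_prompts, int(np.ceil(len(tokenized_prompts) / n_devices))):
    if batch.shape[0] % n_devices:
        padded_size = int(np.ceil(batch.shape[0] / n_devices) * n_devices)
        pad = jnp.tile(batch[-1:], (padded_size - batch.shape[0], 1))
        batch = jnp.concatenate([batch, pad])
    print(batch.shape)  # every batch now shards evenly across 4 devices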
def __call__(
    self,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Compute (optionally masked) MHA with queries, keys & values."""
    # Fused in-projection: a single matmul yields queries, keys and values,
    # which jnp.array_split then carves into three equal chunks.
    all_out = jnp.dot(x, self.in_proj_weight.transpose())
    all_out += self.in_proj_bias
    q, k, v = jnp.array_split(all_out, 3, axis=-1)

    query_heads = self._split(q)
    key_heads = self._split(k)
    value_heads = self._split(v)

    attention_logits = jnp.einsum("tbhd,Tbhd->bhtT", query_heads, key_heads)
    sqrt_key_size = np.sqrt(self.model_size // self.num_heads).astype(k.dtype)
    attention_logits = attention_logits / sqrt_key_size
    if self.attn_mask is not None:
        attention_logits += self.attn_mask

    attention_weights = jax.nn.softmax(attention_logits)
    attention = jnp.einsum("bhtT,Tbhd->tbhd", attention_weights, value_heads)

    # Concatenate attention matrix of all heads into a single vector.
    attention_vec = jnp.reshape(attention, (*q.shape[:2], -1))

    return self.out_proj(attention_vec)
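# --- Hedged sketch (illustrative shapes, not from the original source) ---
# Demonstrates the fused in-projection pattern above in isolation: one
# matmul produces the stacked QKV output, and jnp.array_split splits it
# into three equal chunks along the feature axis.
import jax
import jax.numpy as jnp

model_size, seq_len, batch = 8, 5, 2
w = jax.random.normal(jax.random.PRNGKey(0), (3 * model_size, model_size))
x = jnp.ones((seq_len, batch, model_size))
all_out = jnp.dot(x, w.transpose())            # (seq_len, batch, 3 * model_size)
q, k, v = jnp.array_split(all_out, 3, axis=-1)
print(q.shape, k.shape, v.shape)               # (5, 2, 8) three times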
def split(self, npartitions):
    """
    Split configs into npartitions new configs objects for parallelization.

    Args:
        npartitions: int, number of partitions to divide configs into

    Returns:
        configslist: list of new configs objects
    """
    return [
        OpenConfigs(c) for c in jnp.array_split(self.configs, npartitions)
    ]
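# --- Hedged sketch ---
# jnp.array_split tolerates partition counts that do not divide the number
# of configurations evenly: the pieces differ in length by at most one.
# A plain array stands in for the project's configuration storage here.
import jax.numpy as jnp

configs = jnp.arange(20).reshape(10, 2)  # stand-in for 10 walker configurations
parts = jnp.array_split(configs, 3)
print([p.shape[0] for p in parts])  # [4, 3, 3]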
def shuffled_batched_indices(
    rng: Key,
    stream_len: int,
    batch_size: int,
    drop_last: bool = False,
):
    """Shuffle ``range(stream_len)`` and split it into batch-sized index chunks."""
    if isinstance(stream_len, list):
        # stream_len is a sequence of indices already, or a list of objects.
        stream_len = len(stream_len)
    shuffled = jax.random.permutation(rng, jnp.arange(0, stream_len))
    # Split at every multiple of batch_size; the final chunk may be ragged.
    shuffled_batched = jnp.array_split(
        shuffled,
        jnp.arange(batch_size, stream_len, batch_size),
    )
    if stream_len % batch_size and drop_last:
        # Discard the trailing ragged batch if requested.
        shuffled_batched = shuffled_batched[:-1]
    return shuffled_batched
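# --- Hedged usage sketch (assumes the definition above is in scope) ---
# With stream_len=10 and batch_size=4, the split points
# jnp.arange(4, 10, 4) == [4, 8] yield batches of sizes 4, 4 and 2;
# passing drop_last=True would discard the trailing ragged batch.
import jax

rng = jax.random.PRNGKey(0)
batches = shuffled_batched_indices(rng, stream_len=10, batch_size=4)
print([len(b) for b in batches])  # [4, 4, 2]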
def array_split(ary, indices_or_sections, axis: int = 0):
    """Thin wrapper around ``jnp.array_split`` that unwraps JaxArray inputs."""
    ary = _remove_jaxarray(ary)
    if isinstance(indices_or_sections, JaxArray):
        indices_or_sections = indices_or_sections.value
    return jnp.array_split(ary, indices_or_sections, axis)
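# --- Hedged sketch ---
# With plain arrays the wrapper above is a pass-through to jnp.array_split
# (only the project's JaxArray container needs unwrapping), so the
# underlying behaviour can be shown with jnp.array_split directly:
import jax.numpy as jnp

pieces = jnp.array_split(jnp.arange(9), 2)
print([p.tolist() for p in pieces])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8]]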
def run_model(
    model_func,
    data,
    ep,
    num_samples=500,
    num_warmup=500,
    num_chains=4,
    target_accept=0.75,
    max_tree_depth=15,
    save_results=True,
    output_fname=None,
    model_kwargs=None,
    save_json=False,
    chain_method="parallel",
    heuristic_step_size=True,
):
    """
    Model run utility

    :param model_func: numpyro model
    :param data: PreprocessedData object
    :param ep: EpidemiologicalParameters object
    :param num_samples: number of samples
    :param num_warmup: number of warmup samples
    :param num_chains: number of chains
    :param target_accept: target acceptance probability
    :param max_tree_depth: maximum tree depth
    :param save_results: whether to save full results
    :param output_fname: output filename
    :param model_kwargs: model kwargs -- extra arguments for the model function
    :param save_json: whether to save json
    :param chain_method: Numpyro chain method to use
    :param heuristic_step_size: whether to find a heuristic step size
    :return: posterior_samples, warmup_samples, info_dict (dict with assorted
        diagnostics), Numpyro mcmc object
    """
    print(
        f"Running {num_chains} chains, {num_samples} per chain with {num_warmup} warmup steps"
    )
    nuts_kernel = NUTS(
        model_func,
        init_strategy=init_to_median,
        target_accept_prob=target_accept,
        max_tree_depth=max_tree_depth,
        find_heuristic_step_size=heuristic_step_size,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        num_warmup=num_warmup,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    rng_key = random.PRNGKey(0)

    # hmcstate = nuts_kernel.init(rng_key, 1, model_args=(data, ep))
    # nRVs = hmcstate.adapt_state.inverse_mass_matrix.size
    # inverse_mass_matrix = init_diag_inv_mass_mat * jnp.ones(nRVs)
    # mass_matrix_sqrt_inv = np.sqrt(inverse_mass_matrix)
    # mass_matrix_sqrt = 1. / mass_matrix_sqrt_inv
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(inverse_mass_matrix=inverse_mass_matrix))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt_inv=mass_matrix_sqrt_inv))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt=mass_matrix_sqrt))
    # mcmc.post_warmup_state = hmcstate

    info_dict = {
        "model_name": model_func.__name__,
    }

    start = time.time()

    if model_kwargs is None:
        model_kwargs = {}
    info_dict["model_kwargs"] = model_kwargs

    # Also collect some extra information for better diagnostics!
    print(f"Warmup Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.warmup(
        rng_key,
        data,
        ep,
        **model_kwargs,
        collect_warmup=True,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )
    mcmc.get_extra_fields()["num_steps"].block_until_ready()

    info_dict["warmup"] = {}
    info_dict["warmup"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["warmup"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()
    info_dict["warmup"]["inverse_mass_matrix"] = {}

    # Split the stacked warmup mass-matrix history into per-chain blocks.
    all_mass_mats = jnp.array(
        jnp.array_split(
            mcmc.get_extra_fields()["adapt_state"].inverse_mass_matrix,
            num_chains,
            axis=0,
        ))
    print(all_mass_mats.shape)
    for i in range(num_chains):
        info_dict["warmup"]["inverse_mass_matrix"][
            f"chain_{i}"] = all_mass_mats[i, -1, :].tolist()

    info_dict["warmup"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()

    warmup_samples = mcmc.get_samples()

    print(f"Sample Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.run(
        rng_key,
        data,
        ep,
        **model_kwargs,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )
    posterior_samples = mcmc.get_samples()

    # If you don't block here, the timer won't quite work properly.
    posterior_samples[list(posterior_samples.keys())[0]].block_until_ready()
    print(f"Sample Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    end = time.time()
    time_per_sample = float(end - start) / num_samples
    divergences = int(mcmc.get_extra_fields()["diverging"].sum())

    info_dict["time_per_sample"] = time_per_sample
    info_dict["total_runtime"] = float(end - start)
    info_dict["divergences"] = divergences

    info_dict["sample"] = {}
    info_dict["sample"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["sample"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()
    info_dict["sample"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()

    print(f"Sampling {num_samples} samples per chain took {end - start:.2f}s")
    print(f"There were {divergences} divergences.")

    grouped_posterior_samples = mcmc.get_samples(True)

    all_ess = np.array([])
    for k in grouped_posterior_samples.keys():
        ess = numpyro.diagnostics.effective_sample_size(
            np.asarray(grouped_posterior_samples[k]))
        all_ess = np.append(all_ess, ess)
    print(f"{np.sum(np.isnan(all_ess))} ESS were nan")
    all_ess = all_ess[np.logical_not(np.isnan(all_ess))]

    info_dict["ess"] = {
        "med": float(np.percentile(all_ess, 50)),
        "lower": float(np.percentile(all_ess, 2.5)),
        "upper": float(np.percentile(all_ess, 97.5)),
        "min": float(np.min(all_ess)),
        "max": float(np.max(all_ess)),
    }
    print(
        f"Mean ESS: {info_dict['ess']['med']:.2f} [{info_dict['ess']['lower']:.2f} ... {info_dict['ess']['upper']:.2f}]"
    )

    if num_chains > 1:
        all_rhat = np.array([])
        for k in grouped_posterior_samples.keys():
            rhat = numpyro.diagnostics.gelman_rubin(
                np.asarray(grouped_posterior_samples[k]))
            all_rhat = np.append(all_rhat, rhat)
        print(f"{np.sum(np.isnan(all_rhat))} Rhat were nan")
        all_rhat = all_rhat[np.logical_not(np.isnan(all_rhat))]

        info_dict["rhat"] = {
            "med": float(np.percentile(all_rhat, 50)),
            "upper": float(np.percentile(all_rhat, 97.5)),
            "lower": float(np.percentile(all_rhat, 2.5)),
            "min": float(np.min(all_rhat)),
            "max": float(np.max(all_rhat)),
        }
        print(
            f"Rhat: {info_dict['rhat']['med']:.2f} [{info_dict['rhat']['lower']:.2f} ... {info_dict['rhat']['upper']:.2f}]"
        )

    if save_results:
        print("Saving .netcdf")
        try:
            inf_data = az.from_numpyro(mcmc)
            if output_fname is None:
                output_fname = f'{model_func.__name__}-{datetime.now(tz=None).strftime("%d-%m;%H-%M-%S")}.netcdf'
            az.to_netcdf(inf_data, output_fname)
            json_fname = output_fname.replace(".netcdf", ".json")
            if save_json:
                print("Saving Json")
                with open(json_fname, "w") as f:
                    json.dump(info_dict, f, ensure_ascii=False, indent=4)
        except Exception as e:
            print(e)

    return posterior_samples, warmup_samples, info_dict, mcmc
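# --- Hedged sketch (shapes invented for illustration) ---
# The warmup block above uses jnp.array_split to carve the stacked
# (num_chains * num_warmup, n_rvs) inverse-mass-matrix history into equal
# per-chain blocks; the last row of each block is that chain's adapted
# matrix, which is what the loop records per chain.
import jax.numpy as jnp

num_chains, num_warmup, n_rvs = 4, 3, 5
history = jnp.arange(num_chains * num_warmup * n_rvs, dtype=jnp.float32)
history = history.reshape(num_chains * num_warmup, n_rvs)
per_chain = jnp.array(jnp.array_split(history, num_chains, axis=0))
print(per_chain.shape)       # (4, 3, 5)
print(per_chain[0, -1, :])   # chain 0's final inverse mass matrix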
def decollapse_and_split(self, x):
    # Decollapse batches and split alpha from color channels.
    x = jnp.reshape(
        x, (x.shape[0] // self.num_slots, self.num_slots,
            *x.shape[1:]))  # Decollapse batches from slots.
    x, alphas = jnp.array_split(x, [x.shape[-1] - 1], -1)
    return x, alphas
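# --- Hedged sketch ---
# jnp.array_split with an explicit index list [C - 1] separates the last
# channel (alpha) from the color channels, matching the call above.
import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 4))                            # e.g. RGBA feature maps
rgb, alpha = jnp.array_split(x, [x.shape[-1] - 1], -1)
print(rgb.shape, alpha.shape)                         # (2, 8, 8, 3) (2, 8, 8, 1)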