def get_parameter_and_state_names(self, layer):
  # Store the names of the parameters for the scan loop
  with make_functional_modules([layer]) as ([apply_fun], \
                                            params, \
                                            (state, constants, rng_seq), \
                                            finalize):
    bundle_name = current_bundle_name()

    # Filter out the params and states that aren't a part of this repeat
    filtered_params = {key: val for (key, val) in params.items() if key.startswith(bundle_name)}
    filtered_state = {key: val for (key, val) in state.items() if key.startswith(bundle_name)}

    # Order the parameters correctly and separate the keys from values
    sorted_params = sorted(filtered_params.items(), key=lambda x: x[0])
    sorted_state = sorted(filtered_state.items(), key=lambda x: x[0])

    param_names, param_vals = zip(*sorted_params)
    param_shapes = util.tree_shapes(param_vals)

    if len(sorted_state) == 0:
      state_names = ()
      state_shapes = ()
    else:
      state_names, state_vals = zip(*sorted_state)
      state_shapes = util.tree_shapes(state_vals)

    finalize(params, (state, constants, rng_seq))

  return (param_names, state_names), (param_shapes, state_shapes)
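# Illustration, not part of the library: a minimal standalone sketch of the pattern above --
# filter a flat Haiku-style parameter dictionary by a module-name prefix, sort the surviving
# entries by name, and record their shapes.  The dictionary layout and the "repeat/layer_i"
# names in the usage comment are made up for the example; only the filter/sort/shape logic
# mirrors get_parameter_and_state_names.
import jax
import jax.numpy as jnp

def _filtered_names_and_shapes(params, bundle_name):
  # Keep only the entries belonging to this bundle, in a deterministic order
  filtered = {k: v for k, v in params.items() if k.startswith(bundle_name)}
  sorted_items = sorted(filtered.items(), key=lambda kv: kv[0])
  if len(sorted_items) == 0:
    return (), ()
  names, values = zip(*sorted_items)
  # Record the shape of every leaf in each entry's pytree
  shapes = tuple(jax.tree_util.tree_map(lambda x: x.shape, v) for v in values)
  return names, shapes

# Example usage with a hypothetical parameter dictionary:
# _example_params = {"repeat/layer_0/linear": {"w": jnp.ones((3, 3)), "b": jnp.ones((3,))},
#                    "repeat/layer_1/linear": {"w": jnp.ones((3, 3)), "b": jnp.ones((3,))},
#                    "other_module/linear":   {"w": jnp.ones((2, 2)), "b": jnp.ones((2,))}}
# _filtered_names_and_shapes(_example_params, "repeat")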
def flow_norm_init(self,
                   inputs: Mapping[str, jnp.ndarray],
                   rng: jnp.ndarray=None,
                   sample: Optional[bool]=False,
                   **kwargs):
  """ Initialize this layer so that its outputs are normally distributed """
  # Check if we've set flow norm (will be False the first time)
  flow_norm_set = get_constant("flow_norm_set", False, do_not_set=True)

  # Record that flow norm initialization has now been run
  get_constant("flow_norm_set", True, do_not_set=False)

  if not flow_norm_set:
    # Train this layer over the input batch to generate a unit normal output
    def loss_fun(inputs, rng, sample=False, **kwargs):
      outputs = self(inputs, rng, sample=sample, **kwargs)
      z = outputs["x"]

      @self.auto_batch
      def unit_gaussian(z):
        return -0.5*jnp.sum(z.ravel()**2)  # + const

      log_pz = unit_gaussian(z)
      log_px = log_pz + outputs["log_det"]
      return -log_px.mean()

    with make_functional_modules([loss_fun]) as ([apply_fun], \
                                                 params, \
                                                 state, \
                                                 finalize_params_and_state):
      import optax
      opt_init, opt_update = optax.adam(learning_rate=1e-4)
      opt_state = opt_init(params)
      opt_update = jit(opt_update)

      grad_fun = jax.value_and_grad(apply_fun, has_aux=True)
      grad_fun = partial(grad_fun, sample=sample, **kwargs)
      grad_fun = jit(grad_fun)

      import tqdm
      pbar = tqdm.tqdm(list(enumerate(random.split(rng, 200))))
      for i, rng in pbar:
        (loss, state), grad = grad_fun(params, state, inputs, rng)

        updates, opt_state = opt_update(grad, opt_state, params)
        if jnp.any(jnp.isnan(jax.flatten_util.ravel_pytree(updates)[0])):
          break
        params = jit(optax.apply_updates)(params, updates)

        pbar.set_description(f"loss: {loss}")

      finalize_params_and_state(params, state)
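# Illustration, not part of the library: a minimal sketch of the optimization pattern used by
# flow_norm_init -- jit-compiled optax Adam updates driven by jax.value_and_grad(..., has_aux=True),
# with a NaN guard on the update pytree.  The toy loss below (drive the output of a linear map
# toward a unit Gaussian) and all names are made up for the example; the real layer, its state,
# and make_functional_modules are dropped for brevity.
import jax
import jax.numpy as jnp
import jax.flatten_util
import optax
from jax import random, jit

def _toy_flow_norm_style_fit(rng, inputs, n_steps=200):
  params = {"w": jnp.eye(inputs.shape[-1])}
  state = {"step": jnp.zeros(())}                      # stand-in for Haiku state

  def loss_fun(params, state, x, rng):
    z = x @ params["w"]
    neg_log_pz = 0.5*jnp.sum(z**2, axis=-1)            # unit-Gaussian energy (up to a constant)
    new_state = {"step": state["step"] + 1}
    return neg_log_pz.mean(), new_state

  opt_init, opt_update = optax.adam(learning_rate=1e-4)
  opt_state = opt_init(params)
  grad_fun = jit(jax.value_and_grad(loss_fun, has_aux=True))

  for step_rng in random.split(rng, n_steps):
    (loss, state), grad = grad_fun(params, state, inputs, step_rng)
    updates, opt_state = opt_update(grad, opt_state, params)
    # Stop early rather than apply a NaN update
    if jnp.any(jnp.isnan(jax.flatten_util.ravel_pytree(updates)[0])):
      break
    params = optax.apply_updates(params, updates)
  return params, state

# Example usage (hypothetical data):
# _toy_flow_norm_style_fit(random.PRNGKey(0), random.normal(random.PRNGKey(1), (32, 4)))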
def forward(self, x, rng):
  self.init_if_needed(x, rng)

  batch_info = self.unbatched_input_shapes["x"], self.batch_shape

  fun = partial(self.auto_batched_res_block, update_params=False)
  with make_functional_modules([fun]) as ([apply_fun], \
                                          params, \
                                          state, \
                                          finalize):
    if self.use_trace_estimator:
      z, log_det = res_flow_sliced_estimate(apply_fun, params, state, x, rng, batch_info)
    else:
      z, log_det = res_flow_estimate(apply_fun, params, state, x, rng, batch_info)

    finalize(params, state)

  return z, log_det
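# Illustration, not part of the library: res_flow_estimate / res_flow_sliced_estimate are defined
# elsewhere.  Below is a minimal sketch of the kind of Hutchinson-style estimate such routines
# build on -- a truncated power series for log|det(I + dg/dx)| of a residual block z = x + g(x),
# probed with a single Gaussian vector via reverse-mode vjp.  Truncation makes it biased; the
# library's estimators may differ (e.g. sliced or unbiased Russian-roulette variants).  The
# function name, the contractive toy block in the usage comment, and n_terms are assumptions.
import jax
import jax.numpy as jnp
from jax import random

def _residual_logdet_estimate(g, x, rng, n_terms=4):
  g_x, vjp_fun = jax.vjp(g, x)          # g(x) and a function computing w -> w^T J
  v = random.normal(rng, x.shape)       # Hutchinson probe vector
  w = v
  log_det = 0.0
  for k in range(1, n_terms + 1):
    (w,) = vjp_fun(w)                   # w <- w^T J, i.e. one more power of the Jacobian
    log_det = log_det + ((-1.0)**(k + 1))/k*jnp.sum(w*v)   # ~ (-1)^{k+1} tr(J^k)/k
  return x + g_x, log_det

# Example usage with a contractive toy residual block:
# _residual_logdet_estimate(lambda x: 0.5*jnp.tanh(x), jnp.ones(3), random.PRNGKey(0))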
def invert(self, z, rng):
  self.init_if_needed(z, rng)

  # State will be held constant during the fixed point iterations
  fun = partial(self.auto_batched_res_block, update_params=False)
  with make_functional_modules([fun]) as ([apply_fun], \
                                          params, \
                                          state, \
                                          finalize):
    # Make sure we don't use a different random key at every step of the fixed point iterations.
    deterministic_apply_fun = lambda params, state, x: apply_fun(params, state, x, rng)

    # Run the fixed point iterations to invert at z.  We can do reverse-mode through this!
    x = fixed_point(deterministic_apply_fun, params, state, z, rng)

    # Update the Haiku global states
    finalize(params, state)

  return x
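# Illustration, not part of the library: the fixed_point routine above is defined elsewhere
# (and supports reverse-mode differentiation).  This is a minimal sketch of the underlying idea
# for a residual block z = x + g(x): iterate x <- z - g(x), which converges when g is
# contractive (Lipschitz constant below 1).  The function name and n_iters are assumptions.
import jax
import jax.numpy as jnp

def _invert_residual_block(g, z, n_iters=100):
  # Banach fixed-point iteration, starting from x = z
  def body(_, x):
    return z - g(x)
  return jax.lax.fori_loop(0, n_iters, body, z)

# Example usage with a contractive toy residual block:
# _invert_residual_block(lambda x: 0.5*jnp.tanh(x), jnp.ones(3))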
def forward(self, x, rng, update_params):
  self.init_if_needed(x, rng)

  batch_info = self.unbatched_input_shapes["x"], self.batch_shape

  with make_functional_modules([self.auto_batched_res_block]) as ([apply_fun], \
                                                                  params, \
                                                                  state, \
                                                                  finalize):
    if self.use_trace_estimator:
      z, log_det, state = res_flow_sliced_estimate(apply_fun, params, state, x, rng, batch_info)
    else:
      z, log_det, state = res_flow_estimate(apply_fun, params, state, x, rng, batch_info)

    # Ensure that we don't backprop through state (this shouldn't affect anything)
    state = jax.lax.stop_gradient(state)

    # Update the Haiku global states
    finalize(params, state)

  return z, log_det
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         no_scan: bool=False,
         accumulate: Iterable[str]=["log_det", "aux_loss"],
         **kwargs) -> Mapping[str, jnp.ndarray]:
  if Layer._is_initializing:
    return self.call_no_scan(inputs, rng, sample=sample, **kwargs)

  # Want to make sure that we're passing all inputs/outputs to the next layer
  final_outputs = inputs.copy()

  # Need to get the functional apply fun
  with make_functional_modules([self.layer_create_fun()]) as ([apply_fun], \
                                                              params, \
                                                              (state, constants, rng_seq), \
                                                              finalize):
    # Retrieve the hashes of the names of the parameters and states for the layer call
    param_hashes, state_hashes = get_constant("param_state_name_hashes", None)

    # Batch together the parameters and state across the repeated layers
    scan_params = _batch_repeated_layers(params, param_hashes)
    scan_params = data_structures.to_immutable_dict(scan_params)

    scan_state = _batch_repeated_layers(state, state_hashes)

    # Reverse the order if we are sampling
    if sample == True:
      scan_params = jax.tree_map(lambda x: x[::-1], scan_params)
      scan_state = jax.tree_map(lambda x: x[::-1], scan_state)

    # Pass other inputs we might have through the network
    shared_inputs = inputs.copy()
    del shared_inputs["x"]

    # Use a scan loop so that we only need to compile the layer once!
    def scan_body(carry, scan_inputs):
      x = carry
      params, state, rng = scan_inputs

      # Bundle the non-parameter state together
      bundled_state = (state, constants, rng_seq)

      # Make sure that we're passing all of the inputs (such as labels) to the layer
      inputs = shared_inputs.copy()
      inputs["x"] = x

      # Run the function
      outputs, bundled_state = apply_fun(params, bundled_state, inputs, rng, sample=sample, **kwargs)

      # Retrieve the state because it might have changed
      state, _, _ = bundled_state

      # Return the stuff we need
      x = outputs["x"]
      del outputs["x"]

      return x, (outputs, state)

    # Run the scan function
    rngs = random.split(rng, self.n_repeats) if rng is not None else [None]*self.n_repeats
    x, (batched_outputs, batched_updated_state) = jax.lax.scan(scan_body, inputs["x"], (scan_params, scan_state, rngs))

    # Reverse the updated state if we are sampling
    if sample == True:
      batched_updated_state = jax.tree_map(lambda x: x[::-1], batched_updated_state)

    # Search through the outputs to find things we want to accumulate
    accumulated_outputs = {}
    for name in accumulate:
      if name in batched_outputs:
        accumulated_outputs[name] = batched_outputs[name].sum(axis=0)
        del batched_outputs[name]

    # Convert the output of the scan into the same state data structure that was passed in.
    hash_map = {hash(k): k for k in state.keys()}
    rev_hash_map = {k: hash(k) for k in state.keys()}
    updated_state = state.copy()
    for base_layer_name, pytree in batched_updated_state.items():

      # Retrieve the names of each repeated layer
      layer_names = [hash_map[k] for k in state_hashes[rev_hash_map[base_layer_name]]]

      # Split the batched parameters
      leaves, treedef = jax.tree_flatten(batched_updated_state[base_layer_name])
      split_states = [jax.tree_unflatten(treedef, [l[i] for l in leaves]) for i in range(self.n_repeats)]

      # Update the state dictionary
      updated_state.update(dict(zip(layer_names, split_states)))

    # Just in case
    updated_state = jax.lax.stop_gradient(updated_state)

    # Only state might be different
    bundled_state = (updated_state, constants, rng_seq)
    finalize(params, bundled_state)

  outputs = {"x": x}
  outputs.update(accumulated_outputs)
  return outputs
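# Illustration, not part of the library: a minimal sketch of the scan pattern used above.
# The parameters of N identical layers are stacked along a leading axis, and the layer body is
# applied once inside jax.lax.scan so it is compiled a single time, with the per-layer log-det
# collected and summed afterwards (like accumulate=["log_det"]).  The toy elementwise affine
# "layer", its parameter names, and n_repeats are made up for the example; Haiku state,
# constants, and the name-hash bookkeeping are omitted.
import jax
import jax.numpy as jnp
from jax import random

def _toy_repeated_layers(x, rng, n_repeats=4):
  # One stacked parameter pytree: the leading axis indexes the repeated layer
  k1, k2 = random.split(rng)
  stacked_params = {
      "log_scale": 0.01*random.normal(k1, (n_repeats,) + x.shape),
      "shift": 0.01*random.normal(k2, (n_repeats,) + x.shape),
  }

  def scan_body(carry, layer_params):
    x = carry
    # Apply one "layer": an elementwise affine map with a closed-form log-det
    z = x*jnp.exp(layer_params["log_scale"]) + layer_params["shift"]
    log_det = jnp.sum(layer_params["log_scale"])
    return z, {"log_det": log_det}

  z, per_layer = jax.lax.scan(scan_body, x, stacked_params)
  return z, per_layer["log_det"].sum(axis=0)

# Example usage:
# _toy_repeated_layers(jnp.ones(3), random.PRNGKey(0))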