Example #1
    def get_parameter_and_state_names(self, layer):

        # Store the names of the parameters for the scan loop
        with make_functional_modules([layer]) as ([apply_fun], \
                                                   params, \
                                                   (state, constants, rng_seq), \
                                                   finalize):
            bundle_name = current_bundle_name()

            # Filter out the params and states that aren't a part of this repeat
            filtered_params = {
                key: val
                for (key, val) in params.items() if key.startswith(bundle_name)
            }
            filtered_state = {
                key: val
                for (key, val) in state.items() if key.startswith(bundle_name)
            }

            # Order the parameters correctly and separate the keys from values
            sorted_params = sorted(filtered_params.items(), key=lambda x: x[0])
            sorted_state = sorted(filtered_state.items(), key=lambda x: x[0])

            param_names, param_vals = zip(*sorted_params)
            param_shapes = util.tree_shapes(param_vals)
            if len(sorted_state) == 0:
                state_names = ()
                state_shapes = ()
            else:
                state_names, state_vals = zip(*sorted_state)
                state_shapes = util.tree_shapes(state_vals)

            finalize(params, (state, constants, rng_seq))

        return (param_names, state_names), (param_shapes, state_shapes)
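The bundle-filtering logic above depends on library internals (current_bundle_name, util.tree_shapes), but the underlying pattern is just prefix-filtering and sorting a flat Haiku-style parameter dictionary. A minimal standalone sketch of that pattern, with an assumed module naming scheme and jax.tree_util.tree_map standing in for util.tree_shapes:

import jax
import jax.numpy as jnp

# Hypothetical flat, Haiku-style parameter dictionary
params = {
    "repeat/block_0/linear": {"w": jnp.zeros((3, 3)), "b": jnp.zeros((3,))},
    "repeat/block_1/linear": {"w": jnp.zeros((3, 3)), "b": jnp.zeros((3,))},
    "other/linear": {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))},
}
bundle_name = "repeat"  # assumed prefix, playing the role of current_bundle_name()

# Keep only the modules belonging to this bundle, in a deterministic order
filtered = {k: v for k, v in params.items() if k.startswith(bundle_name)}
param_names, param_vals = zip(*sorted(filtered.items(), key=lambda kv: kv[0]))

# Per-leaf shapes, analogous to util.tree_shapes
param_shapes = jax.tree_util.tree_map(lambda x: x.shape, param_vals)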
Example #2
  def flow_norm_init(self,
                     inputs: Mapping[str, jnp.ndarray],
                     rng: jnp.ndarray = None,
                     sample: Optional[bool] = False,
                     **kwargs):
    """ Initialize this layer so that its outputs are normally distributed
    """

    # Check if we've set flow norm (will be False the first time)
    flow_norm_set = get_constant("flow_norm_set", False, do_not_set=True)

    # Set that we're checking
    get_constant("flow_norm_set", True, do_not_set=False)

    if not flow_norm_set:
      # Train this layer over the input batch to generate a unit normal output

      def loss_fun(inputs, rng, sample=False, **kwargs):
        outputs = self(inputs, rng, sample=sample, **kwargs)
        z = outputs["x"]

        @self.auto_batch
        def unit_gaussian(z):
          return -0.5*jnp.sum(z.ravel()**2) # + const
        log_pz = unit_gaussian(z)

        log_px = log_pz + outputs["log_det"]
        return -log_px.mean()

      with make_functional_modules([loss_fun]) as ([apply_fun], \
                                                   params, \
                                                   state, \
                                                   finalize_params_and_state):
        import optax
        opt_init, opt_update = optax.adam(learning_rate=1e-4)
        opt_state = opt_init(params)
        opt_update = jit(opt_update)

        grad_fun = jax.value_and_grad(apply_fun, has_aux=True)
        grad_fun = partial(grad_fun, sample=sample, **kwargs)
        grad_fun = jit(grad_fun)

        import tqdm
        pbar = tqdm.tqdm(list(enumerate(random.split(rng, 200))))
        for i, rng in pbar:
          (loss, state), grad = grad_fun(params, state, inputs, rng)
          updates, opt_state = opt_update(grad, opt_state, params)
          if jnp.any(jnp.isnan(jax.flatten_util.ravel_pytree(updates)[0])):
            break
          params = jit(optax.apply_updates)(params, updates)

          pbar.set_description(f"loss: {loss}")

        finalize_params_and_state(params, state)
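Most of flow_norm_init is a standard optax training loop over a functionalized loss: adam, jax.value_and_grad with has_aux=True (the auxiliary output carries the updated state), and optax.apply_updates. A self-contained sketch of that loop on a toy loss, not the library's flow-normalization objective:

import jax
import jax.numpy as jnp
import optax

# Toy stand-in for the functional apply_fun: returns (loss, new_state)
def loss_fn(params, state, x):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - 1.0)**2), state

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
state = {}
x = jnp.ones((8, 4))

opt_init, opt_update = optax.adam(learning_rate=1e-4)
opt_state = opt_init(params)
grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

for _ in range(200):
    (loss, state), grads = grad_fn(params, state, x)
    updates, opt_state = opt_update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)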
Example #3
    def forward(self, x, rng):
        self.init_if_needed(x, rng)

        batch_info = self.unbatched_input_shapes["x"], self.batch_shape

        fun = partial(self.auto_batched_res_block, update_params=False)
        with make_functional_modules([fun]) as ([apply_fun], \
                                                 params, \
                                                 state, \
                                                 finalize):
            if self.use_trace_estimator:
                z, log_det = res_flow_sliced_estimate(apply_fun, params, state,
                                                      x, rng, batch_info)
            else:
                z, log_det = res_flow_estimate(apply_fun, params, state, x,
                                               rng, batch_info)

            finalize(params, state)

        return z, log_det
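res_flow_estimate and res_flow_sliced_estimate are library-specific, but both rest on stochastic trace estimation of the residual branch's Jacobian, since log det(I + J_g) for a residual block z = x + g(x) is approximated from traces of Jacobian powers. A minimal sketch of the basic ingredient, a Hutchinson trace estimate via forward-mode jvp (toy residual branch, not the library's estimator):

import jax
import jax.numpy as jnp
from jax import random

def hutchinson_trace(g, x, rng):
    # Unbiased estimate of tr(J_g(x)) using a single Gaussian probe vector
    v = random.normal(rng, x.shape)      # E[v v^T] = I
    _, jv = jax.jvp(g, (x,), (v,))       # forward-mode product J_g(x) @ v
    return jnp.vdot(v, jv)               # E[v^T J v] = tr(J_g(x))

g = lambda x: 0.1 * jnp.tanh(x)          # toy contractive residual branch
print(hutchinson_trace(g, jnp.ones(16), random.PRNGKey(0)))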
Example #4
    def invert(self, z, rng):
        self.init_if_needed(z, rng)

        # State will be held constant during the fixed point iterations
        fun = partial(self.auto_batched_res_block, update_params=False)

        with make_functional_modules([fun]) as ([apply_fun], \
                                                params, \
                                                state, \
                                                finalize):
            # Make sure we don't use a different random key at every step of the fixed point iterations.
            deterministic_apply_fun = lambda params, state, x: apply_fun(
                params, state, x, rng)

            # Run the fixed point iterations to invert at z.  We can do reverse-mode through this!
            x = fixed_point(deterministic_apply_fun, params, state, z, rng)

            # Update the Haiku global states
            finalize(params, state)

        return x
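fixed_point is defined elsewhere in the library; the idea is that a residual block z = x + g(x) can be inverted by iterating x_{k+1} = z - g(x_k), which converges when g is contractive (Lipschitz constant below one). A minimal sketch of that iteration with jax.lax.fori_loop, not the library's fixed_point implementation:

import jax
import jax.numpy as jnp

def invert_residual(g, z, n_iters=100):
    # Fixed-point iteration x_{k+1} = z - g(x_k); converges if Lip(g) < 1
    return jax.lax.fori_loop(0, n_iters, lambda _, x: z - g(x), z)

g = lambda x: 0.5 * jnp.tanh(x)
x_true = jnp.linspace(-1.0, 1.0, 5)
z = x_true + g(x_true)
x_rec = invert_residual(g, z)            # recovers x_true up to iteration error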
Example #5
    def forward(self, x, rng, update_params):
        self.init_if_needed(x, rng)

        batch_info = self.unbatched_input_shapes["x"], self.batch_shape

        with make_functional_modules([self.auto_batched_res_block]) as ([apply_fun], \
                                                                        params, \
                                                                        state, \
                                                                        finalize):
            if self.use_trace_estimator:
                z, log_det, state = res_flow_sliced_estimate(
                    apply_fun, params, state, x, rng, batch_info)
            else:
                z, log_det, state = res_flow_estimate(apply_fun, params, state,
                                                      x, rng, batch_info)

            # Ensure that we don't backprop through the updated state (this shouldn't affect anything)
            state = jax.lax.stop_gradient(state)

            # Update the Haiku global states
            finalize(params, state)

        return z, log_det
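The stop_gradient call above operates on the whole state pytree at once, which is how the updated state is kept out of the autodiff graph. A small sketch of that behavior (toy function, not the flow layer):

import jax
import jax.numpy as jnp

def f(x):
    running = {"mean": x.mean()}              # stand-in for updated layer state
    running = jax.lax.stop_gradient(running)  # detach the entire pytree
    return (x**2).sum() + running["mean"]

print(jax.grad(f)(jnp.arange(3.0)))           # [0. 2. 4.]: the detached term contributes no gradient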
Example #6
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             no_scan: bool = False,
             accumulate: Iterable[str] = ["log_det", "aux_loss"],
             **kwargs) -> Mapping[str, jnp.ndarray]:
        if Layer._is_initializing:
            return self.call_no_scan(inputs, rng, sample=sample, **kwargs)

        # Want to make sure that we're passing all inputs/outputs to the next layer
        final_outputs = inputs.copy()

        # Need to get the functional apply function
        with make_functional_modules([self.layer_create_fun()]) as ([apply_fun], \
                                                                    params, \
                                                                    (state, constants, rng_seq), \
                                                                    finalize):
            # Retrieve the hashes of the names of the parameters and states for the layer call
            param_hashes, state_hashes = get_constant(
                "param_state_name_hashes", None)

            # Batch together the parameters and state across the repeated layers
            scan_params = _batch_repeated_layers(params, param_hashes)
            scan_params = data_structures.to_immutable_dict(scan_params)
            scan_state = _batch_repeated_layers(state, state_hashes)

            # Reverse the order if we are sampling
            if sample:
                scan_params = jax.tree_map(lambda x: x[::-1], scan_params)
                scan_state = jax.tree_map(lambda x: x[::-1], scan_state)

            # Pass other inputs we might have through the network
            shared_inputs = inputs.copy()
            del shared_inputs["x"]

            # Use a scan loop so that we only need to compile the layer once!
            def scan_body(carry, scan_inputs):
                x = carry
                params, state, rng = scan_inputs

                # Bundle the non-parameter state together
                bundled_state = (state, constants, rng_seq)

                # Make sure that we're passing all of the inputs (such as labels) to the layer
                inputs = shared_inputs.copy()
                inputs["x"] = x

                # Run the function
                outputs, bundled_state = apply_fun(params,
                                                   bundled_state,
                                                   inputs,
                                                   rng,
                                                   sample=sample,
                                                   **kwargs)

                # Retrieve the state because it might have changed
                state, _, _ = bundled_state

                # Return the carry and the per-layer outputs and state
                x = outputs["x"]
                del outputs["x"]
                return x, (outputs, state)

            # Run the scan function
            rngs = random.split(
                rng,
                self.n_repeats) if rng is not None else [None] * self.n_repeats
            x, (batched_outputs, batched_updated_state) = jax.lax.scan(
                scan_body, inputs["x"], (scan_params, scan_state, rngs))

            # Reverse the updated state if we are sampling
            if sample:
                batched_updated_state = jax.tree_map(lambda x: x[::-1],
                                                     batched_updated_state)

            # Search through the outputs to find things we want to accumulate
            accumulated_outputs = {}
            for name in accumulate:
                if name in batched_outputs:
                    accumulated_outputs[name] = batched_outputs[name].sum(
                        axis=0)
                    del batched_outputs[name]

            # Convert the output of the scan into the same state data structure that was passed in.
            hash_map = {hash(k): k for k in state.keys()}
            rev_hash_map = {k: hash(k) for k in state.keys()}
            updated_state = state.copy()
            for base_layer_name, pytree in batched_updated_state.items():

                # Retrieve the names of each repeated layer
                layer_names = [
                    hash_map[k]
                    for k in state_hashes[rev_hash_map[base_layer_name]]
                ]

                # Split the batched parameters
                leaves, treedef = jax.tree_flatten(
                    batched_updated_state[base_layer_name])
                split_states = [
                    jax.tree_unflatten(treedef, [l[i] for l in leaves])
                    for i in range(self.n_repeats)
                ]

                # Update the state dictionary
                updated_state.update(dict(zip(layer_names, split_states)))

            # Just in case, make sure we don't differentiate through the updated state
            updated_state = jax.lax.stop_gradient(updated_state)

            # Only state might be different
            bundled_state = (updated_state, constants, rng_seq)
            finalize(params, bundled_state)

        outputs = {"x": x}
        outputs.update(accumulated_outputs)

        return outputs
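The core trick in call is to stack the parameters (and state) of the repeated layers along a leading axis and drive a single apply function with jax.lax.scan, so the repeated layer is traced and compiled only once. A stripped-down sketch of that pattern with a toy linear layer (names and shapes are illustrative, not the library's):

import jax
import jax.numpy as jnp
from jax import random

n_repeats, dim = 4, 8
keys = random.split(random.PRNGKey(0), n_repeats)

# Parameters of every repeated layer, stacked along a leading axis
stacked_params = {
    "w": jnp.stack([0.1 * random.normal(k, (dim, dim)) for k in keys]),
    "b": jnp.zeros((n_repeats, dim)),
}

def layer_apply(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

def scan_body(x, params):
    # Each scan step receives one slice of the stacked parameters
    return layer_apply(params, x), None

x_out, _ = jax.lax.scan(scan_body, jnp.ones(dim), stacked_params)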