コード例 #1
0
def generate_data_01():
    """Build a small two-layer homotopy network and run it on one random batch.

    Returns:
      (inputs, logits, ae_params, bparam, init_func, predict_func) where
      ``inputs`` is a float32 array of shape ``(8, 4)`` drawn from ``npr.rand``,
      ``logits`` is the network output for ``inputs`` evaluated with
      continuation parameter ``bparam[0]`` (initialized to ``[0.0]``),
      ``ae_params`` are the initialized parameters, and ``init_func`` /
      ``predict_func`` are the pair returned by ``stax.serial``.
    """
    batch_size = 8
    input_shape = (batch_size, 4)

    def synth_batches():
        # Endless stream of random float32 batches of shape ``input_shape``.
        while True:
            yield npr.rand(*input_shape).astype("float32")

    inputs = next(synth_batches())

    init_func, predict_func = stax.serial(
        HomotopyDense(out_dim=4, W_init=glorot_uniform(), b_init=normal()),
        HomotopyDense(out_dim=1, W_init=glorot_uniform(), b_init=normal()),
        Sigmoid,
    )

    # Only the initialized parameters are needed; the output shape is ignored.
    _, ae_params = init_func(random.PRNGKey(0), input_shape)
    bparam = [np.array([0.0], dtype=np.float64)]
    logits = predict_func(ae_params,
                          inputs,
                          bparam=bparam[0],
                          activation_func=sigmoid)
    return inputs, logits, ae_params, bparam, init_func, predict_func
コード例 #2
0
 def objective(params, bparam) -> float:
     """Mean prediction error on the enclosing ``inputs``/``outputs`` plus L2 terms.

     ``predict_fun``, ``inputs``, ``outputs`` and ``sigmoid`` come from the
     enclosing scope.
     """
     predictions = predict_fun(params, inputs, bparam=bparam[0],
                               activation_func=sigmoid)
     # NOTE(review): this is the signed mean difference, not a squared error,
     # so it is not bounded below by zero — confirm this is intended.
     data_term = np.mean(np.subtract(predictions, outputs))
     penalty = l2_norm(params) + l2_norm(bparam)
     return data_term + penalty
コード例 #3
0
    def objective(params, bparam, batch) -> float:
        """Masked reconstruction loss plus a tiny regularizer.

        The batch inputs are flattened per example, reconstructed by
        ``predict_fun``, and compared (MSE) against a copy of the inputs in
        which each entry is kept with probability ``bparam[0]`` and zeroed
        otherwise.
        """
        images, _ = batch
        flat = np.reshape(images, (images.shape[0], -1))
        reconstruction = predict_fun(params, flat, bparam=bparam[0], rng=key)
        # NOTE(review): the same PRNG `key` is also passed to predict_fun above;
        # reusing a JAX key correlates the two random draws — confirm intended.
        keep_mask = random.bernoulli(key, bparam[0], flat.shape)
        target = np.where(keep_mask, flat, 0)
        mse = np.mean(np.square((np.subtract(reconstruction, target))))
        # The 1 / ||bparam|| term grows without bound as bparam approaches zero.
        regularizer = 1e-8 * (l2_norm(params) + 1 / l2_norm(bparam))
        return mse + regularizer
コード例 #4
0
    def testUtilityClipGrads(self):
        """clip_grads passes small trees through and rescales large ones."""
        tree = (np.ones(2), (np.ones(3), np.ones(4)))
        full_norm = optimizers.l2_norm(tree)

        # Threshold above the tree's norm: values come back unchanged.
        passthrough = optimizers.clip_grads(tree, 1.1 * full_norm)
        self.assertAllClose(passthrough, tree, check_dtypes=False)

        # Threshold below the tree's norm: the clipped norm equals the threshold.
        threshold = 0.9 * full_norm
        clipped = optimizers.clip_grads(tree, threshold)
        self.assertAllClose(optimizers.l2_norm(clipped), threshold,
                            check_dtypes=False)
コード例 #5
0
ファイル: conv_nn.py プロジェクト: harsh306/continuation-jax
    def objective(params, bparam) -> float:
        """Cross-entropy of the CNN on the enclosing ``inputs``/``outputs`` plus L2 terms.

        ``outputs`` are integer labels for 10 classes (one-hot encoded here).
        """
        logits = CNN().apply({"params": params[0]}, inputs)
        one_hot_labels = jax.nn.one_hot(outputs, num_classes=10)
        # NOTE(review): this equals cross-entropy only if the CNN returns
        # log-probabilities (e.g. log-softmax) — confirm.
        data_loss = -np.mean(np.sum(one_hot_labels * logits, axis=-1))
        return data_loss + (l2_norm(params) + l2_norm(bparam))
コード例 #6
0
    def update(self, params, opt_state, batch: util.Transition):
        """The actual update function.

        Takes one optimizer step on ``batch``.

        Returns:
          (new_params, new_opt_state, logs): ``logs`` is the auxiliary dict
          returned by ``self._loss``, extended with the gradient norm
          (``grad_norm_unclipped``) and the post-update parameter norm
          (``weight_norm``).
        """
        loss_and_grad_fn = jax.value_and_grad(self._loss, has_aux=True)
        (_, logs), grads = loss_and_grad_fn(params, batch)

        grad_norm_unclipped = optimizers.l2_norm(grads)
        updates, new_opt_state = self._opt.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)

        logs.update({
            'grad_norm_unclipped': grad_norm_unclipped,
            'weight_norm': optimizers.l2_norm(new_params),
        })
        return new_params, new_opt_state, logs
コード例 #7
0
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder.

    Arguments:
      params: a dictionary of LFADS parameters
      lfads_hps: a dictionary of LFADS hyperparameters
      key: random.PRNGKey for random bits
      x_bxt: np array of input with leading dims being batch and time
      kl_scale: scale on KL
      keep_rate: dropout keep rate

    Returns:
      a dictionary of all losses, including the key 'total' used for
      optimization
    """
    batch_size = lfads_hps['batch_size']
    _, subkeys = utils.keygen(key, 2)
    fwd = batch_lfads(params, lfads_hps,
                      random.split(next(subkeys), batch_size), x_bxt,
                      keep_rate)

    # Each term is summed over time and state dims and averaged over batch.

    # KL - initial condition (g0) posterior vs. params['ic_prior'].
    kl_g0_b = dists.batch_kl_gauss_gauss(fwd['ic_mean'], fwd['ic_logvar'],
                                         params['ic_prior'],
                                         lfads_hps['var_min'])
    kl_g0_prescale = np.sum(kl_g0_b) / batch_size
    kl_g0 = kl_scale * kl_g0_prescale

    # KL - inferred-input posterior (mean, log-variance) vs. params['ii_prior'].
    ii_keys_b = random.split(next(subkeys), batch_size)
    kl_ii_b = dists.batch_kl_gauss_ar1(ii_keys_b, fwd['ii_mean_t'],
                                       fwd['ii_logvar_t'], params['ii_prior'],
                                       lfads_hps['var_min'])
    kl_ii_prescale = np.sum(kl_ii_b) / batch_size
    kl_ii = kl_scale * kl_ii_prescale

    # Poisson log-likelihood of the data given the inferred log-rates.
    log_p_xgz = np.sum(
        dists.poisson_log_likelihood(x_bxt, fwd['lograte_t'])) / batch_size

    # Squared-L2 penalty on all parameters.
    l2_loss = lfads_hps['l2reg'] * optimizers.l2_norm(params)**2

    total = -log_p_xgz + kl_g0 + kl_ii + l2_loss
    return {
        'total': total,
        'nlog_p_xgz': -log_p_xgz,
        'kl_g0': kl_g0,
        'kl_g0_prescale': kl_g0_prescale,
        'kl_ii': kl_ii,
        'kl_ii_prescale': kl_ii_prescale,
        'l2': l2_loss
    }
コード例 #8
0
def objective(params, bparam, batch) -> float:
    """Classification loss on ``batch`` plus a tiny L2 penalty on ``params``.

    Inputs are flattened per example before ``predict_fun``. The data term is
    ``-mean(sum(outputs * targets))``, i.e. cross-entropy assuming
    ``predict_fun`` returns log-probabilities (TODO confirm) and ``targets``
    are one-hot. ``bparam`` is not part of the penalty.
    """
    images, targets = batch
    flat_images = np.reshape(images, (images.shape[0], -1))
    scores = predict_fun(params, flat_images, bparam=bparam[0],
                         activation_func=relu)
    data_loss = -np.mean(np.sum(scores * targets, axis=1))
    return data_loss + 5e-7 * (l2_norm(params))
コード例 #9
0
def loss(params, inputs_bxtxu, targets_bxtxo, l2reg):
  """Compute the least squares loss of the output, plus L2 regularization.

  Returns:
    dict with the 'total' loss and its components: 'lms' (mean squared error
    of the RNN outputs) and 'l2' (``l2reg * ||params||**2``).
  """
  _, predicted_bxtxo = batched_rnn_run(params, inputs_bxtxu)
  weight_penalty = l2reg * optimizers.l2_norm(params)**2
  residual_bxtxo = predicted_bxtxo - targets_bxtxo
  mse = np.mean(residual_bxtxo**2)
  return {'total': mse + weight_penalty, 'lms': mse, 'l2': weight_penalty}
コード例 #10
0
    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        hparams,
    ):
        """Set up continuation state, objectives, optimizers and gradient fns.

        Args:
          state: current problem parameters (wrapped in a StateVariable).
          bparam: current continuation parameter (wrapped in a StateVariable).
          state_0: previous problem parameters.
          bparam_0: previous continuation parameter.
          counter: initial counter handed to every StateVariable.
          objective: callable ``objective(state, bparam)``; jitted as
            ``value_func`` and differentiated w.r.t. its first argument.
          dual_objective: callable differentiated w.r.t. arguments 0 and 1
            (min step) and argument 2 (max step).
          hparams: hyper-parameter dict; reads ``meta`` (``optimizer``,
            ``ascent_optimizer``, ``output_dir``), ``natural_lr``,
            ``ascent_lr``, ``continuation_steps``, ``lagrange_init``,
            ``delta_s`` and ``omega``.
        """
        # Current and previous points of the continuation.
        self._state_wrap = StateVariable(state, counter)
        # TODO: save tree def, always unflatten before compute_grads
        self._bparam_wrap = StateVariable(bparam, counter)
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # Objectives and their jitted value function.
        self.objective = objective
        self.dual_objective = dual_objective
        self.value_func = jit(self.objective)

        self.hparams = hparams

        # TODO: fix with a static batch (test/train)
        self._value_wrap = StateVariable(1.0, counter)
        self._quality_wrap = StateVariable(l2_norm(self._state_wrap.state),
                                           counter)

        # Descent ("natural") and ascent optimizers.
        meta = hparams["meta"]
        self.opt = OptimizerCreator(
            opt_string=meta["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        self.ascent_opt = OptimizerCreator(
            opt_string=meta["ascent_optimizer"],
            learning_rate=hparams["ascent_lr"],
        ).get_optimizer()

        # Per-step continuation hyper-parameters.
        self.continuation_steps = hparams["continuation_steps"]
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # Jitted gradient functions; the objectives should be pure functions.
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        self.compute_max_grad_fn = jit(grad(self.dual_objective, [2]))
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # Placeholders populated later (e.g. ``sw`` by the run loop).
        self.sw = None
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = meta["output_dir"]
        self.prev_secant_direction = None
コード例 #11
0
def loss(params, batch):
    """Masked classification loss plus weight decay.

    ``batch`` is ``(inputs, targets, adj, is_training, rng, idx)``. Only the
    nodes selected by ``idx`` contribute to the data term. That term is
    cross-entropy if ``predict_fun`` returns log-probabilities (TODO confirm).
    """
    inputs, targets, adj, is_training, rng, idx = batch
    node_scores = predict_fun(params, inputs, adj, is_training=is_training,
                              rng=rng)
    selected = node_scores[idx] * targets[idx]
    data_loss = -np.mean(np.sum(selected, axis=1))
    # Squared norm (no sqrt), matching TensorFlow's weight-decay convention.
    weight_decay = 5e-4 * optimizers.l2_norm(params)**2
    return data_loss + weight_decay
コード例 #12
0
def grouper(iterable, threshold=0.01):
    """Split ``iterable`` into runs of consecutive, mutually close items.

    A new group starts whenever
    ``l2_norm(pytree_sub(item, previous_item)) > threshold``. Each group is
    yielded as a list. An empty iterable yields nothing.

    Args:
      iterable: sequence of pytrees/arrays to group, in order.
      threshold: maximum L2 distance between neighbours in the same group.
    """
    prev = None
    group = []
    for index, item in enumerate(iterable):
        # Detect the first item by position, not truthiness: `not prev` is
        # ambiguous for multi-element arrays and misfires on falsy items such
        # as 0 or empty containers.
        if index == 0 or l2_norm(pytree_sub(item, prev)) <= threshold:
            group.append(item)
        else:
            yield group
            group = [item]
        prev = item
    if group:
        yield group
コード例 #13
0
    def correction_step(self) -> Tuple:
        """Given the current state, optimize towards the corrected state.

        Runs up to ``self.warmup_period`` passes of ``self.num_batches``
        gradient steps on ``self._state``. After a pass in which the local
        test was met, it stops. The local test is either gradient norm below
        ``quality_thresh``, or a plateau of the 10-step running mean of the
        loss within ``loss_tol``. Finally it evaluates on the test set.

        Returns:
          (state, bparam, quality, value, val_loss, val_acc): the optimized
          problem parameters; the continuation parameter (not modified here);
          the norm of the last gradients (1.0 if no step ran); the training
          loss of the last step (None if no step ran); and the test-set loss
          and accuracy.
        """
        quality = 1.0
        value = None  # stays None when warmup_period or num_batches is 0
        ma_loss = []
        stop = False
        print("learn_rate", self.opt.lr)
        for k in range(self.warmup_period):
            for _ in range(self.num_batches):
                batch = next(self.data_loader)
                grads = self.grad_fn(self._state, self._bparam, batch)
                self._state = self.opt.update_params(self._state, grads[0])
                quality = l2_norm(grads)
                value = self.value_fn(self._state, self._bparam, batch)
                ma_loss.append(value)
                self.opt.lr = exp_decay(k, self.hparams["natural_lr"])
                if self.hparams["local_test_measure"] == "norm_gradients":
                    if quality > self.hparams["quality_thresh"]:
                        print(f"quality {quality}, {self.opt.lr} ,{k}")
                    else:
                        stop = True
                        print(f"quality {quality} stopping at , {k}th step")
                elif len(ma_loss) >= 20:
                    # Loss plateau: the last two running means are within tol.
                    tmp_means = running_mean(ma_loss, 10)
                    if math.isclose(
                            tmp_means[-1],
                            tmp_means[-2],
                            abs_tol=self.hparams["loss_tol"],
                    ):
                        print(f"stopping at , {k}th step")
                        stop = True
            # The stop flag is only honoured at the end of a full pass.
            if stop:
                print("breaking")
                break

        test_batch = (self.test_images, self.test_labels)
        val_loss = self.value_fn(self._state, self._bparam, test_batch)
        val_acc = self.accuracy_fn(self._state, self._bparam, test_batch)
        return self._state, self._bparam, quality, value, val_loss, val_acc
コード例 #14
0
ファイル: rnn.py プロジェクト: LPompe/TorchDynamics
def loss(params, inputs_bxtxu, targets_bxtxo, targets_mask_t, l2reg):
    """Compute the least squares loss of the output, plus L2 regularization.

    Only the time steps listed in ``targets_mask_t`` enter the squared error.

    Args:
      params: dict RNN parameters
      inputs_bxtxu: np array of inputs batch x time x input dim
      targets_bxtxo: np array of targets batch x time x output dim
      targets_mask_t: list of time indices where target is active
      l2reg: float, hyper parameter controlling strength of L2 regularization

    Returns:
      dict of losses with keys 'total', 'lms' and 'l2'
    """
    _, outs_bxtxo = batched_rnn_run(params, inputs_bxtxu)
    l2_loss = l2reg * optimizers.l2_norm(params)**2
    selected_outs = outs_bxtxo[:, targets_mask_t, :]
    selected_targets = targets_bxtxo[:, targets_mask_t, :]
    lms_loss = np.mean((selected_outs - selected_targets)**2)
    return {'total': lms_loss + l2_loss, 'lms': lms_loss, 'l2': l2_loss}
コード例 #15
0
def loss_fn(params, data, rng, batch_size, ic_prior, var_min, kl_scale, l2reg):
    """Total training loss of the model on one batch.

    :param params: model parameters passed to ``lfads_batch``.
    :param data: batch of observations; presumably its leading axis has length
        ``batch_size`` — TODO confirm.
    :param rng: PRNG key, split into ``batch_size`` per-example keys.
    :param batch_size: number of examples; also the divisor that averages the
        summed KL and log-likelihood terms.
    :param ic_prior: prior handed to the initial-condition KL term.
    :param var_min: minimum variance handed to the KL term.
    :param kl_scale: multiplier on the KL term.
    :param l2reg: coefficient on the squared L2 norm of ``params``.
    :return: ``-log p(data | rates) + kl_scale * KL + l2reg * ||params||**2``.
    """
    per_example_keys = random.split(rng, batch_size)

    # Run the data through the model.
    result = lfads_batch(params, per_example_keys, data)

    # Initial-condition KL, averaged over the batch and scaled.
    kl_per_example = dists.batch_kl_gauss_gauss(result['ic_post_mean'],
                                                result['ic_post_logvar'],
                                                ic_prior, var_min)
    kl_loss_g0 = kl_scale * (np.sum(kl_per_example) / batch_size)

    # Poisson log-likelihood of the data given neuron_log_rates.
    log_p_xgz = np.sum(
        dists.poisson_log_likelihood(data,
                                     result['neuron_log_rates'])) / batch_size

    l2_loss = l2reg * optimizers.l2_norm(params)**2

    return -log_p_xgz + kl_loss_g0 + l2_loss
コード例 #16
0
def ssvm_loss(params,
              x,
              y,
              lamb=0.01,
              max_steps=80,
              step_size=0.1,
              pretrain_global_energy=False):
    """Structured-SVM hinge loss, or label prediction when ``y is None``.

    A relaxed label assignment is optimized with momentum(0.01, 0.95) for up
    to ``max_steps`` steps. The loop ends early when ``check_saddle_point``
    fires. In prediction mode the projected assignment is returned.
    Otherwise the result is the batch-mean hinge
    ``max(||y_hat - y||^2 + E(y) - E(y_hat), 0)`` plus ``lamb * ||params||``.

    Note: ``step_size`` is accepted but not used by this function.
    """
    predicting = y is None
    x_hat = compute_feature_energy(params, x)
    if pretrain_global_energy:
        # Block gradients from flowing back through compute_feature_energy.
        x_hat = lax.stop_gradient(x_hat)
    step_fn = inference_step if predicting else cost_augmented_inference_step

    opt_init, opt_update, get_params = momentum(0.01, 0.95)
    opt_state = opt_init(np.zeros(x.shape[:-1] + (LABELS, )))
    last_energy = None
    for step in range(max_steps):
        y_hat = project(get_params(opt_state))
        g, energy = step_fn(y_hat, y, x_hat, params)
        opt_state = opt_update(step, g, opt_state)
        at_saddle = step > 0 and check_saddle_point(
            step, get_params(opt_state), y_hat, energy, last_energy)
        if at_saddle:
            break
        last_energy = energy

    y_hat = lax.stop_gradient(project(get_params(opt_state)))
    if predicting:
        return y_hat

    y = lax.stop_gradient(y)
    pred_energy = compute_global_energy(params, x_hat, y_hat)
    true_energy = compute_global_energy(params, x_hat, y)
    delta = np.square(y_hat - y).sum(axis=1)
    hinge = np.maximum(delta + true_energy - pred_energy, 0)
    return np.mean(hinge) + lamb * l2_norm(params)
コード例 #17
0
 def l2(params):
     """Return ``l2_penalty`` times the squared L2 norm of the RNN parameters.

     ``params`` is a 3-tuple (embedding, rnn, readout); only the middle group
     is penalized.
     """
     _, rnn_params, _ = params
     return l2_penalty * jnp.power(optimizers.l2_norm(rnn_params), 2)
コード例 #18
0
 def pretrain_loss(params, batch, l2=0.05):
     """Classification loss on ``batch`` plus ``(l2 / 2) * ||params[1]||**2``.

     The data term is ``-mean(sum(y * predict(params, X)))``.
     """
     _, lr_params = params
     X, y = batch
     scores = predict(params, X)
     data_term = -np.mean(np.sum(y * scores, axis=1))
     return data_term + (l2 / 2) * l2_norm(lr_params)**2.
コード例 #19
0
    # Fragment of a per-epoch training loop (the enclosing loop/function is not
    # shown): one pass over a freshly shuffled training set, taking one
    # optimizer step per batch.
    print(f"Epoch {epoch}, seed={seed}")
    # Shuffle with the epoch-specific `seed` so batch order changes per epoch.
    train_loader = tfds.as_numpy(
        train_data.shuffle(n_train, seed=seed).batch(args.batch_size_train))
    X, Y = None, None
    for i, batch in enumerate(train_loader):
        # Each batch gets its own seed derived from (seed, i); the 'train'
        # split is passed only when data augmentation is enabled.
        batch = process_data(batch,
                             flatten=args.flatten,
                             centralize_y=args.centralize_y,
                             split='train' if args.data_augment else 'test',
                             seed=seed * args.seed_separator + i,
                             num_classes=args.num_classes)
        X, Y = batch['image'], batch['label']
        params_curr = get_params(state)
        loss_curr, grad_curr = value_and_grad_loss(params_curr, X, Y)
        # monitor gradient norm (measured before the optional clipping below)
        grad_norm = optimizers.l2_norm(grad_curr)
        writer.add_scalar(f'grad_norm/{tb_flag}', grad_norm.item(),
                          global_step)
        # Abort the run on a NaN loss. NOTE(review): sys.exit() with no
        # argument exits with status 0 and prints nothing — consider reporting
        # the failure.
        if np.isnan(loss_curr):
            sys.exit()
        running_loss += loss_curr
        running_count += 1
        # Optional global-norm gradient clipping.
        if args.grad_norm_thresh > 0:
            grad_curr = optimizers.clip_grads(grad_curr, args.grad_norm_thresh)
        # NOTE(review): the step index passed to opt_apply is `epoch`, not
        # `global_step`; confirm this is intended if the optimizer uses a
        # step-dependent schedule.
        state = opt_apply(epoch, grad_curr, state)
        global_step += 1
        print(
            f"Step {global_step}, training loss={loss_curr:.4f}, grad norm={grad_norm:.4f}"
        )

        # Evaluate on the test set
コード例 #20
0
ファイル: train.py プロジェクト: ChrisWaites/data-deletion
def projection(tree, max_norm=1.):
    """Project a pytree of arrays onto the L2 ball of radius ``max_norm``.

    If the global L2 norm of ``tree`` is below ``max_norm``, leaf values are
    returned unchanged. Otherwise every leaf is scaled by ``max_norm / norm``,
    giving the result norm ``max_norm``.
    """
    total_norm = l2_norm(tree)
    scale = max_norm / total_norm

    def _shrink(leaf):
        return np.where(total_norm < max_norm, leaf, leaf * scale)

    return tree_map(_shrink, tree)
コード例 #21
0
 def testUtilityNorm(self):
     """optimizers.l2_norm equals sqrt(sum of squares) over every leaf."""
     tree = (np.ones(2), (np.ones(3), np.ones(4)))
     expected = onp.sqrt(onp.sum(onp.ones(2 + 3 + 4)**2))
     actual = optimizers.l2_norm(tree)
     self.assertAllClose(actual, expected, check_dtypes=False)
コード例 #22
0
ファイル: train.py プロジェクト: ChrisWaites/data-deletion
    # Fragment of a training script (enclosing function not shown): train the
    # model, report accuracy and the L2 norm of `lr_params`, then convert the
    # extracted features to host arrays and prepare an output directory.
    opt_state = opt_init(params)

    # Split the key: `temp` seeds the data stream, `rng` keeps the remainder.
    temp, rng = random.split(rng)
    batches = data_stream(temp, batch_size, X, y)

    for i in tqdm(range(iterations)):
        opt_state = update(i, opt_state, next(batches))
    fe_params, lr_params = get_params(opt_state)

    print('Accuracy (train): {:.4f}'.format(
        accuracy((fe_params, lr_params), predict, X, y)))
    print('Accuracy (test): {:.4f}'.format(
        accuracy((fe_params, lr_params), predict, X_test, y_test)))

    print(lr_params)
    print('L2 norm: {:.4f}'.format(l2_norm(lr_params)))

    # Extract features
    # (`onp` is presumably plain NumPy, giving host-side copies — confirm.)
    X_train_proj = onp.asarray(feature_extractor(fe_params, X))
    y_train_proj = onp.asarray(y)
    X_test_proj = onp.asarray(feature_extractor(fe_params, X_test))
    y_test_proj = onp.asarray(y_test)

    # Output directory is named after the sample count and feature dimension.
    n = str(X.shape[0])
    dim = str(X_train_proj.shape[1])
    directory = 'n={}_d={}'.format(n, dim)
    # WARNING: an existing directory with this name is deleted recursively.
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

    # Dump training and eval script
コード例 #23
0
ファイル: train.py プロジェクト: ChrisWaites/data-deletion
 def loss(params, batch, l2=0.05):
     """Binary cross-entropy on ``batch`` plus ``(l2 / 2) * ||params[1]||**2``.

     Assumes ``y`` holds 0/1 labels and ``predict`` returns probabilities —
     TODO confirm.
     """
     X, y = batch
     y_hat = predict(params, X).reshape(-1)
     # Probability assigned to the observed label.
     # NOTE(review): log(0) gives -inf if y_hat reaches exactly 0 or 1.
     label_prob = y * y_hat + (1. - y) * (1. - y_hat)
     return -np.mean(np.log(label_prob)) + (l2 / 2) * l2_norm(params[1])**2.
コード例 #24
0
def pytree_normalized(x):
    """Return ``x`` scaled so that its global L2 norm (over all leaves) is 1.

    The norm is computed once for the whole tree instead of once per leaf,
    which the per-leaf lambda used to do. A tree with zero norm produces
    non-finite leaves.
    """
    norm = l2_norm(x)
    return tree_util.tree_map(lambda leaf: leaf / norm, x)
コード例 #25
0
def loss(W, b):
    """Softmax cross-entropy of ``predict(W, b, inputs)`` plus an L2 term.

    ``targets`` are presumably one-hot rows. The penalty is
    ``0.001 * (||W|| + ||b||)`` (unsquared norms).
    """
    logits = predict(W, b, inputs)
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    data_loss = -jnp.mean(jnp.sum(log_probs * targets, axis=1))
    return data_loss + 0.001 * (l2_norm(W) + l2_norm(b))
コード例 #26
0
def pytree_relative_error(x, y):
    """Sum over matching leaves of ``||a - b|| / (||a|| + 1e-5)``.

    ``x`` and ``y`` must have the same tree structure. Uses
    ``tree_util.tree_map`` with two trees: ``tree_multimap`` is deprecated
    and removed in recent JAX, and ``tree_map`` accepts extra trees with the
    same semantics.
    """
    per_leaf_error = tree_util.tree_map(
        lambda a, b: l2_norm(pytree_sub(a, b)) / (l2_norm(a) + 1e-5), x, y)
    return tree_util.tree_reduce(lax.add, per_leaf_error)
コード例 #27
0
 def objective_grad(self, params, bparam):  # TODO: JIT?
     """Return ||d objective / d params|| + ||d objective / d bparam||."""
     grads_fn = grad(self.objective, [0, 1])
     d_params, d_bparam = grads_fn(params, bparam)
     return l2_norm(params_grad := d_params) + l2_norm(d_bparam)
コード例 #28
0
def cross_entropy_loss(params, x, y, lamb=0.001):
    """Return ``-mean(sigmoid_cross_entropy(-apply_mlp(params, x), y)) + lamb * ||params||``.

    NOTE(review): both the MLP output and the mean are negated; confirm this
    matches ``sigmoid_cross_entropy``'s sign convention.
    """
    negated_logits = -apply_mlp(params, x)
    mean_term = np.mean(sigmoid_cross_entropy(negated_logits, y))
    return -mean_term + lamb * l2_norm(params)
コード例 #29
0
    def run(self):
        """Runs the continuation strategy.

        A continuation strategy that defines how predictor and corrector components of the algorithm
        interact with the states of the mathematical system.

        Each of the ``self.continuation_steps`` iterations writes the current
        state/bparam/value records to
        ``{output_file}/version_{key_state}.json``, takes a secant predictor
        step, runs a ``PerturbedCorrecter`` from the predicted point, and
        stores the corrected state, bparam and objective value.
        """
        self.sw = StateWriter(f"{self.output_file}/version_{self.key_state}.json")

        for i in range(self.continuation_steps):
            print(self._value_wrap.get_record(), self._bparam_wrap.get_record())
            # Stamp every record with the current step index, then persist them.
            self._state_wrap.counter = i
            self._bparam_wrap.counter = i
            self._value_wrap.counter = i
            self.sw.write(
                [
                    self._state_wrap.get_record(),
                    self._bparam_wrap.get_record(),
                    self._value_wrap.get_record(),
                ]
            )

            # Predictor input: current point, previous point, and the secant
            # direction kept from the previous iteration.
            concat_states = [
                (self._state_wrap.state, self._bparam_wrap.state),
                (self._prev_state, self._prev_bparam),
                self.prev_secant_direction,
            ]

            predictor = SecantPredictor(
                concat_states=concat_states,
                delta_s=self._delta_s,
                omega=self._omega,
                net_spacing_param=self.hparams["net_spacing_param"],
                net_spacing_bparam=self.hparams["net_spacing_bparam"],
                hparams=self.hparams,
            )
            predictor.prediction_step()
            self.prev_secant_direction = predictor.secant_direction
            # sphere_radius is updated in the shared hparams dict, which is
            # passed to the corrector below.
            self.hparams["sphere_radius"] = (
                0.005 * self.hparams["omega"] * l2_norm(predictor.secant_direction)
            )
            # Corrector input: predicted state/bparam, secant direction and its
            # concatenated form.
            concat_states = [
                predictor.state,
                predictor.bparam,
                predictor.secant_direction,
                predictor.get_secant_concat(),
            ]

            # Release the predictor before building the corrector.
            del predictor
            gc.collect()
            # NOTE(review): pred_state and pred_prev_state receive the same
            # (current, pre-correction) values, and counter is the total step
            # count rather than `i` — confirm both are intended.
            corrector = PerturbedCorrecter(
                optimizer=self.opt,
                objective=self.objective,
                dual_objective=self.dual_objective,
                lagrange_multiplier=self._lagrange_multiplier,
                concat_states=concat_states,
                delta_s=self._delta_s,
                ascent_opt=self.ascent_opt,
                key_state=self.key_state,
                compute_min_grad_fn=self.compute_min_grad_fn,
                compute_max_grad_fn=self.compute_max_grad_fn,
                compute_grad_fn=self.compute_grad_fn,
                hparams=self.hparams,
                pred_state=[self._state_wrap.state, self._bparam_wrap.state],
                pred_prev_state=[self._state_wrap.state, self._bparam_wrap.state],
                counter=self.continuation_steps,
            )
            # The current point becomes the "previous" point for the next step.
            self._prev_state = copy.deepcopy(self._state_wrap.state)
            self._prev_bparam = copy.deepcopy(self._bparam_wrap.state)

            # NOTE(review): `quality` is not used and _quality_wrap is never
            # updated here.
            state, bparam, quality = corrector.correction_step()
            value = self.value_func(state, bparam)
            # Relative change of bparam caused by this predictor/corrector step.
            print(
                "How far ....", pytree_relative_error(self._bparam_wrap.state, bparam)
            )
            self._state_wrap.state = state
            self._bparam_wrap.state = bparam
            self._value_wrap.state = value
            del corrector
            gc.collect()
コード例 #30
0
def losses(params, hps, key, x_bxt, class_id_b, kl_scale, keep_rate):
  """Compute the training loss of the LFADS autoencoder.

  Args:
    params: a dictionary of LFADS parameters
    hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    class_id_b: class ids, np array of integers for what classes x_bxt are in
    kl_scale: scale on KL
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for optimization

  Raises:
    NotImplementedError: if hps['train_on'] is neither 'spike_counts' nor
      'continuous'.
  """

  B = hps['batch_size']

  keys = random.split(key, 2)
  keys_b = random.split(keys[0], B)
  use_mean = False
  lfads = batch_forward_pass(params, hps, keys_b, x_bxt, class_id_b, keep_rate,
                             use_mean)

  post_mean_bxz = batch_compose_latent(hps, lfads['ib_post_mean'],
                                       lfads['ic_post_mean'],
                                       lfads['ii_post_mean_t'])
  post_logvar_bxz = batch_compose_latent(hps, lfads['ib_post_logvar'],
                                         lfads['ic_post_logvar'],
                                         lfads['ii_post_logvar_t'])

  # Per-example prior parameters, selected by each example's class id.
  prior_mean_bxz = params['prior']['means'][class_id_b]
  prior_logvar_bxz = params['prior']['logvars'][class_id_b]

  # KL between posterior and class-selected prior: summed over axis 1,
  # averaged over the batch.
  kl_loss_bxz = dists.batch_kl_gauss_b_gauss_b(post_mean_bxz, post_logvar_bxz,
                                               prior_mean_bxz, prior_logvar_bxz,
                                               hps['var_min'])

  kl_loss_prescale = np.mean(np.sum(kl_loss_bxz, axis=1))
  kl_loss = kl_scale * kl_loss_prescale

  # Log-likelihood of data given latents, summed and divided by batch size.
  train_on = hps['train_on']
  if train_on == 'spike_counts':
    spikes = x_bxt
    log_p_xgz = (np.sum(dists.poisson_log_likelihood(spikes,
                                                     lfads['lograte_t']))
                 / float(B))
  elif train_on == 'continuous':
    continuous = x_bxt
    mean = lfads['lograte_t']
    logvar = np.zeros(mean.shape)  # TODO(sussillo): hyperparameter
    log_p_xgz = (np.sum(dists.diag_gaussian_log_likelihood(continuous,
                                                           mean, logvar))
                 / float(B))
  else:
    raise NotImplementedError(
        f"Unsupported hps['train_on'] value: {train_on!r}")

  # Implements the idea that inputs to the generator should be minimal, in the
  # sense of attempting to interpret the inferred inputs as actual inputs to a
  # recurrent system, under an assumption of minimal intervention to that
  # system.
  _, _, ii_post_mean_bxtxi = batch_decompose_latent(hps, post_mean_bxz)
  _, _, ii_prior_mean_bxtxi = batch_decompose_latent(hps, prior_mean_bxz)
  ii_l2_loss = hps['ii_l2_reg'] * (np.sum(ii_prior_mean_bxtxi**2) / float(B) +
                                   np.sum(ii_post_mean_bxtxi**2) / float(B))

  # Implements the idea that the average inferred input should be zero.
  if ii_post_mean_bxtxi.shape[2] > 0:
    ii_tavg_loss = (hps['ii_tavg_reg'] *
                    (np.mean(np.mean(ii_prior_mean_bxtxi, axis=1)**2) +
                     np.mean(np.mean(ii_post_mean_bxtxi, axis=1)**2)))
  else:
    ii_tavg_loss = 0.0

  # L2 - TODO(sussillo): exclusion method is not general to pytrees
  l2reg = hps['l2reg']
  l2_ignore = ['prior']
  l2_params = [p for k, p in params.items() if k not in l2_ignore]
  l2_loss = l2reg * optimizers.l2_norm(l2_params)**2

  loss = -log_p_xgz + kl_loss + l2_loss + ii_l2_loss + ii_tavg_loss
  all_losses = {'total': loss, 'nlog_p_xgz': -log_p_xgz,
                'kl': kl_loss, 'kl_prescale': kl_loss_prescale,
                'ii_l2': ii_l2_loss, 'ii_tavg': ii_tavg_loss, 'l2': l2_loss}

  return all_losses