def train_snet2(X,
                y,
                w,
                binary_y: bool = False,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                reg_diff: bool = False,
                penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term].
    """
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layer
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r,
                                                n_units=n_units_r,
                                                nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=binary_y,
        nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=True,
        nonlin=nonlin)

    def init_fun_snet2(rng, input_shape):
        # chain together the layers
        # param should look like [repr, po_0, po_1, prop]
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)

        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)

        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng,
                                                     input_shape_repr)
        return input_shape, [param_repr, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:

        def loss_head(params, batch):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets)**2))
    else:

        def loss_head(params, batch):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)

        return -jnp.sum(targets * jnp.log(preds) +
                        (1 - targets) * jnp.log(1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet2(params, batch, penalty_l2, penalty_diff):
        # params: list[representation, head_0, head_1, head_prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps = predict_fun_repr(params[0], X)

        # pass down to heads
        loss_0 = loss_head(params[1], (reps, y, 1 - w))
        loss_1 = loss_head(params[2], (reps, y, w))

        # pass down to propensity head
        loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)
        weightsq_prop = sum([
            jnp.sum(params[3][i][0]**2)
            for i in range(0, 2 * n_layers_out + 1, 2)
        ])

        weightsq_body = sum(
            [jnp.sum(params[0][i][0]**2) for i in range(0, 2 * n_layers_r, 2)])
        weightsq_head = heads_l2_penalty(params[1], params[2], n_layers_out,
                                         reg_diff, penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + 0.5 * (
                penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + \
                   0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # updating function
        params = get_params(state)
        return opt_update(
            i,
            grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state)

    # initialise states
    _, init_params = init_fun_snet2(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) *
                                                          batch_size, n - 1)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet2(params_curr, (X_val, y_val, w_val), penalty_l2,
                                penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet2(params_curr, (X_val, y_val, w_val), 0,
                                         0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_snet2(get_params(opt_state), (X_val, y_val, w_val), 0,
                             0)
        return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), \
               l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po,
                            predict_fun_head_prop)
def train_snet3(X, y, w, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
                n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: float = DEFAULT_PENALTY_DISC,
                penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True, n_iter_min: int = DEFAULT_N_ITER_MIN,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1, n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED, return_val_loss: bool = False,
                reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet-3, based on the decompostion used in Hassanpour and Greiner (2020)
    """
    # function to train a net with 3 representations
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(X, y, w,
                                                              val_split_prop=val_split_prop,
                                                              seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layers
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r, n_units=n_units_r,
                                                nonlin=nonlin)
    init_fun_repr_small, predict_fun_repr_small = ReprBlock(n_layers=n_layers_r,
                                                            n_units=n_units_r_small, nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(n_layers_out=n_layers_out,
                                                       n_units_out=n_units_out,
                                                       binary_y=binary_y, nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(n_layers_out=n_layers_out,
                                                           n_units_out=n_units_out, binary_y=True,
                                                           nonlin=nonlin)

    def init_fun_snet3(rng, input_shape):
        # chain together the layers
        # param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop]
        # initialise representation layers
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        input_shape_repr_small, param_repr_o = init_fun_repr_small(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        _, param_repr_w = init_fun_repr_small(layer_rng, input_shape)

        # each head gets two representations
        input_shape_repr = input_shape_repr[:-1] + (input_shape_repr[-1] + input_shape_repr_small[
            -1],)

        # initialise output heads
        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
        return input_shape, [param_repr_c, param_repr_o, param_repr_w, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:
        def loss_head(params, batch, penalty):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch, penalty):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return - jnp.sum(weights * (targets * jnp.log(preds) +
                                        (1 - targets) * jnp.log(
                        1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)
        return - jnp.sum(targets * jnp.log(preds) +
                         (1 - targets) * jnp.log(
            1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet3(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # params: list[repr_c, repr_o, repr_t, po_0, po_1, prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps_c = predict_fun_repr(params[0], X)
        reps_o = predict_fun_repr_small(params[1], X)
        reps_w = predict_fun_repr_small(params[2], X)

        # concatenate
        reps_po = _concatenate_representations((reps_c, reps_o))
        reps_prop = _concatenate_representations((reps_c, reps_w))

        # pass down to heads
        loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)

        # pass down to propensity head
        loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)
        weightsq_prop = sum([jnp.sum(params[5][i][0] ** 2) for i in
                             range(0, 2 * n_layers_out + 1, 2)])

        # which variable has impact on which representation
        col_c = _get_absolute_rowsums(params[0][0][0])
        col_o = _get_absolute_rowsums(params[1][0][0])
        col_w = _get_absolute_rowsums(params[2][0][0])
        loss_o = penalty_orthogonal * (
            jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o))

        # is rep_o balanced between groups?
        loss_disc = penalty_disc * mmd2_lin(reps_o, w)

        # weight decay on representations
        weightsq_body = sum([sum([jnp.sum(params[j][i][0] ** 2) for i in
                                  range(0, 2 * n_layers_r, 2)]) for j in range(3)])
        weightsq_head = heads_l2_penalty(params[3], params[4], n_layers_out, reg_diff,
                                         penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + loss_o + loss_disc + \
               0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1)/n_batch + loss_prop/n_batch + loss_o + loss_disc + \
                   0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # updating function
        params = get_params(state)
        return opt_update(i, grad(loss_snet3)(
            params, batch, penalty_l2, penalty_orthogonal, penalty_disc),
                          state)

    # initialise states
    _, init_params = init_fun_snet3(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size, n - 1)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2,
                               penalty_orthogonal,
                               penalty_disc)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet3(params_curr, (X_val, y_val, w_val),
                                penalty_l2, penalty_orthogonal, penalty_disc)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i,
                                                         val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0,
                                         0, 0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
        return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), \
               l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop)