# Example 1
def _train_tnet_jointly(X,
                        y,
                        w,
                        binary_y: bool = False,
                        n_layers_out: int = DEFAULT_LAYERS_OUT,
                        n_units_out: int = DEFAULT_UNITS_OUT,
                        n_layers_r: int = DEFAULT_LAYERS_R,
                        n_units_r: int = DEFAULT_UNITS_R,
                        penalty_l2: float = DEFAULT_PENALTY_L2,
                        step_size: float = DEFAULT_STEP_SIZE,
                        n_iter: int = DEFAULT_N_ITER,
                        batch_size: int = DEFAULT_BATCH_SIZE,
                        val_split_prop: float = DEFAULT_VAL_SPLIT,
                        early_stopping: bool = True,
                        patience: int = DEFAULT_PATIENCE,
                        n_iter_min: int = DEFAULT_N_ITER_MIN,
                        verbose: int = 1,
                        n_iter_print: int = DEFAULT_N_ITER_PRINT,
                        seed: int = DEFAULT_SEED,
                        return_val_loss: bool = False,
                        same_init: bool = True,
                        penalty_diff: float = DEFAULT_PENALTY_L2,
                        nonlin: str = DEFAULT_NONLIN,
                        avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """Train the two outcome heads of a T-learner jointly on one objective.

    Both potential-outcome heads share the same architecture; head 0 is fit
    on control observations (weight 1 - w) and head 1 on treated (weight w),
    with an L2 penalty tying/regularising the two heads via heads_l2_penalty.

    Returns
    -------
    (params, predict_fun_head) and, if return_val_loss, additionally the
    unpenalised validation loss.
    """
    # input check
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)

    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get output head functions (both heads share same structure)
    init_fun_head, predict_fun_head = OutputHead(n_layers_out=n_layers_out,
                                                 n_units_out=n_units_out,
                                                 binary_y=binary_y,
                                                 n_layers_r=n_layers_r,
                                                 n_units_r=n_units_r,
                                                 nonlin=nonlin)

    # Define loss functions
    # loss functions for the head
    if not binary_y:

        def loss_head(params, batch):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return jnp.sum(weights * ((preds - targets)**2))
    else:

        def loss_head(params, batch):
            # binary cross-entropy (log) loss function
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    @jit
    def loss_tnet(params, batch, penalty_l2, penalty_diff):
        # params: list[head_0, head_1] (no shared representation in T-Net)
        # batch: (X, y, w)
        X, y, w = batch

        # pass down to two heads: head 0 fits controls, head 1 fits treated
        loss_0 = loss_head(params[0], (X, y, 1 - w))
        loss_1 = loss_head(params[1], (X, y, w))

        # regularization on both heads (incl. penalty on their difference)
        weightsq_head = heads_l2_penalty(params[0], params[1],
                                         n_layers_r + n_layers_out, True,
                                         penalty_l2, penalty_diff)
        if not avg_objective:
            return loss_0 + loss_1 + 0.5 * (weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + 0.5 * (weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # single adam step on the joint objective
        params = get_params(state)
        return opt_update(
            i,
            grad(loss_tnet)(params, batch, penalty_l2, penalty_diff), state)

    # initialise states
    if same_init:
        # both heads start from identical weights
        _, init_head = init_fun_head(rng_key, input_shape)
        init_params = [init_head, init_head]
    else:
        rng_key, rng_key_2 = random.split(rng_key)
        _, init_head_0 = init_fun_head(rng_key, input_shape)
        _, init_head_1 = init_fun_head(rng_key_2, input_shape)
        init_params = [init_head_0, init_head_1]

    opt_state = opt_init(init_params)
    # fallback for the divergence path below; avoids UnboundLocalError if the
    # very first validation loss is already NaN
    params_best = init_params

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # slice end is exclusive, so cap at n (not n - 1) to avoid
            # silently dropping the last sample of the final batch
            idx_next = train_indices[(b * batch_size):min((b + 1) *
                                                          batch_size, n)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_tnet(params_curr, (X_val, y_val, w_val), penalty_l2,
                               penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, predict_fun_head
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_tnet(params_curr, (X_val, y_val, w_val), 0,
                                        0)
                    return params_curr, predict_fun_head, l_final

                return params_curr, predict_fun_head

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_tnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
        return trained_params, predict_fun_head, l_final

    return trained_params, predict_fun_head
def train_snet2(X,
                y,
                w,
                binary_y: bool = False,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                reg_diff: bool = False,
                penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term].

    Architecture: a shared representation block feeding two potential-outcome
    heads (head 0 for controls, head 1 for treated) plus a binary propensity
    head. Returns (params, (predict_fun_repr, predict_fun_head_po,
    predict_fun_head_prop)) and, if return_val_loss, also the unpenalised
    validation loss.
    """
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        # without difference-regularisation, penalise both heads equally
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layer
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r,
                                                n_units=n_units_r,
                                                nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=binary_y,
        nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=True,
        nonlin=nonlin)

    def init_fun_snet2(rng, input_shape):
        # chain together the layers
        # param should look like [repr, po_0, po_1, prop]
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)

        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)

        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng,
                                                     input_shape_repr)
        return input_shape, [param_repr, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:

        def loss_head(params, batch):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets)**2))
    else:

        def loss_head(params, batch):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        # NOTE: penalty is unused here (regularisation is applied in
        # loss_snet2); kept for signature symmetry
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)

        return -jnp.sum(targets * jnp.log(preds) +
                        (1 - targets) * jnp.log(1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet2(params, batch, penalty_l2, penalty_diff):
        # params: list[representation, head_0, head_1, head_prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps = predict_fun_repr(params[0], X)

        # pass down to heads
        loss_0 = loss_head(params[1], (reps, y, 1 - w))
        loss_1 = loss_head(params[2], (reps, y, w))

        # pass down to propensity head
        loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)
        # even indices hold (W, b) layer tuples; [0] selects the weight matrix
        weightsq_prop = sum([
            jnp.sum(params[3][i][0]**2)
            for i in range(0, 2 * n_layers_out + 1, 2)
        ])

        weightsq_body = sum(
            [jnp.sum(params[0][i][0]**2) for i in range(0, 2 * n_layers_r, 2)])
        weightsq_head = heads_l2_penalty(params[1], params[2], n_layers_out,
                                         reg_diff, penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + 0.5 * (
                penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + \
                   0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # single adam step on the joint objective
        params = get_params(state)
        return opt_update(
            i,
            grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state)

    # initialise states
    _, init_params = init_fun_snet2(rng_key, input_shape)
    opt_state = opt_init(init_params)
    # fallback for the divergence path below; avoids UnboundLocalError if the
    # very first validation loss is already NaN
    params_best = init_params

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # slice end is exclusive, so cap at n (not n - 1) to avoid
            # silently dropping the last sample of the final batch
            idx_next = train_indices[(b * batch_size):min((b + 1) *
                                                          batch_size, n)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet2(params_curr, (X_val, y_val, w_val), penalty_l2,
                                penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet2(params_curr, (X_val, y_val, w_val), 0,
                                         0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_snet2(get_params(opt_state), (X_val, y_val, w_val), 0,
                             0)
        return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), \
               l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po,
                            predict_fun_head_prop)
# Example 3
def train_r_stage2(X,
                   y_ortho,
                   w_ortho,
                   n_layers_out: int = DEFAULT_LAYERS_OUT,
                   n_units_out: int = DEFAULT_UNITS_OUT,
                   n_layers_r: int = 0,
                   n_units_r: int = DEFAULT_UNITS_R,
                   penalty_l2: float = DEFAULT_PENALTY_L2,
                   step_size: float = DEFAULT_STEP_SIZE,
                   n_iter: int = DEFAULT_N_ITER,
                   batch_size: int = DEFAULT_BATCH_SIZE,
                   val_split_prop: float = DEFAULT_VAL_SPLIT,
                   early_stopping: bool = True,
                   patience: int = DEFAULT_PATIENCE,
                   n_iter_min: int = DEFAULT_N_ITER_MIN,
                   verbose: int = 1,
                   n_iter_print: int = DEFAULT_N_ITER_PRINT,
                   seed: int = DEFAULT_SEED,
                   return_val_loss: bool = False,
                   nonlin: str = DEFAULT_NONLIN,
                   avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """Train the second (orthogonalised) stage of an R-learner.

    Fits a single output head minimising the R-loss
    sum((y_ortho - w_ortho * tau(X))^2) plus an L2 weight penalty, where
    y_ortho and w_ortho are outcome and treatment residuals from stage 1.

    Returns (params, predict_fun) and, if return_val_loss, also the
    unpenalised validation loss.
    """
    # input check
    y_ortho, w_ortho = check_shape_1d_data(y_ortho), check_shape_1d_data(
        w_ortho)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none); residuals are continuous, so no
    # stratification on treatment
    X, y_ortho, w_ortho, X_val, y_val, w_val, val_string = make_val_split(
        X,
        y_ortho,
        w_ortho,
        val_split_prop=val_split_prop,
        seed=seed,
        stratify_w=False)
    n = X.shape[0]  # could be different from before due to split

    # get output head
    init_fun, predict_fun = OutputHead(n_layers_out=n_layers_out,
                                       n_units_out=n_units_out,
                                       n_layers_r=n_layers_r,
                                       n_units_r=n_units_r,
                                       nonlin=nonlin)

    # define loss and grad
    @jit
    def loss(params, batch, penalty):
        # orthogonalised (R-)loss with L2 weight penalty
        inputs, ortho_targets, ortho_treats = batch
        preds = predict_fun(params, inputs)
        # even indices hold (W, b) layer tuples; [0] selects the weight matrix
        weightsq = sum([
            jnp.sum(params[i][0]**2)
            for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)
        ])
        if not avg_objective:
            return jnp.sum((ortho_targets - ortho_treats * preds) ** 2) + \
               0.5 * penalty * weightsq
        else:
            return jnp.average((ortho_targets - ortho_treats * preds) ** 2) + \
               0.5 * penalty * weightsq

    # set optimization routine
    # set optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    # set update function
    @jit
    def update(i, state, batch, penalty):
        params = get_params(state)
        g_params = grad(loss)(params, batch, penalty)
        # g_params = optimizers.clip_grads(g_params, 1.0)
        return opt_update(i, g_params, state)

    # initialise states
    _, init_params = init_fun(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # slice end is exclusive, so cap at n (not n - 1) to avoid
            # silently dropping the last sample of the final batch
            idx_next = train_indices[(b * batch_size):min((b + 1) *
                                                          batch_size, n)]
            next_batch = X[idx_next, :], y_ortho[idx_next, :], w_ortho[
                idx_next, :]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss(params_curr, (X_val, y_val, w_val), penalty_l2)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss: {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
            else:
                p_curr = p_curr + 1

            if p_curr > patience:
                trained_params = get_params(opt_state)

                if return_val_loss:
                    # return loss without penalty
                    l_final = loss(trained_params, (X_val, y_val, w_val), 0)
                    return trained_params, predict_fun, l_final

                return trained_params, predict_fun

    # get final parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss(trained_params, (X_val, y_val, w_val), 0)
        return trained_params, predict_fun, l_final

    return trained_params, predict_fun
def train_snet3(X, y, w, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
                n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: float = DEFAULT_PENALTY_DISC,
                penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True, n_iter_min: int = DEFAULT_N_ITER_MIN,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1, n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED, return_val_loss: bool = False,
                reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet-3, based on the decompostion used in Hassanpour and Greiner (2020)

    Learns three representations (confounding, outcome-only, treatment-only):
    the outcome heads read [repr_c, repr_o] and the propensity head reads
    [repr_c, repr_w], with an orthogonality penalty keeping input columns
    from loading on multiple representations and an MMD discrepancy penalty
    balancing repr_o across treatment groups.

    Returns (params, (predict_fun_repr, predict_fun_head_po,
    predict_fun_head_prop)) and, if return_val_loss, also the unpenalised
    validation loss.
    """
    # function to train a net with 3 representations
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        # without difference-regularisation, penalise both heads equally
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(X, y, w,
                                                              val_split_prop=val_split_prop,
                                                              seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layers
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r, n_units=n_units_r,
                                                nonlin=nonlin)
    init_fun_repr_small, predict_fun_repr_small = ReprBlock(n_layers=n_layers_r,
                                                            n_units=n_units_r_small, nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(n_layers_out=n_layers_out,
                                                       n_units_out=n_units_out,
                                                       binary_y=binary_y, nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(n_layers_out=n_layers_out,
                                                           n_units_out=n_units_out, binary_y=True,
                                                           nonlin=nonlin)

    def init_fun_snet3(rng, input_shape):
        # chain together the layers
        # param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop]
        # initialise representation layers
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        input_shape_repr_small, param_repr_o = init_fun_repr_small(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        _, param_repr_w = init_fun_repr_small(layer_rng, input_shape)

        # each head gets two representations
        input_shape_repr = input_shape_repr[:-1] + (input_shape_repr[-1] + input_shape_repr_small[
            -1],)

        # initialise output heads
        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
        return input_shape, [param_repr_c, param_repr_o, param_repr_w, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    # NOTE: penalty is unused in the head losses (regularisation is applied
    # in loss_snet3); kept for signature symmetry
    if not binary_y:
        def loss_head(params, batch, penalty):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch, penalty):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return - jnp.sum(weights * (targets * jnp.log(preds) +
                                        (1 - targets) * jnp.log(
                        1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)
        return - jnp.sum(targets * jnp.log(preds) +
                         (1 - targets) * jnp.log(
            1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet3(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # params: list[repr_c, repr_o, repr_t, po_0, po_1, prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps_c = predict_fun_repr(params[0], X)
        reps_o = predict_fun_repr_small(params[1], X)
        reps_w = predict_fun_repr_small(params[2], X)

        # concatenate
        reps_po = _concatenate_representations((reps_c, reps_o))
        reps_prop = _concatenate_representations((reps_c, reps_w))

        # pass down to heads
        loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)

        # pass down to propensity head
        loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)
        # even indices hold (W, b) layer tuples; [0] selects the weight matrix
        weightsq_prop = sum([jnp.sum(params[5][i][0] ** 2) for i in
                             range(0, 2 * n_layers_out + 1, 2)])

        # which variable has impact on which representation
        col_c = _get_absolute_rowsums(params[0][0][0])
        col_o = _get_absolute_rowsums(params[1][0][0])
        col_w = _get_absolute_rowsums(params[2][0][0])
        loss_o = penalty_orthogonal * (
            jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o))

        # is rep_o balanced between groups?
        loss_disc = penalty_disc * mmd2_lin(reps_o, w)

        # weight decay on representations
        weightsq_body = sum([sum([jnp.sum(params[j][i][0] ** 2) for i in
                                  range(0, 2 * n_layers_r, 2)]) for j in range(3)])
        weightsq_head = heads_l2_penalty(params[3], params[4], n_layers_out, reg_diff,
                                         penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + loss_o + loss_disc + \
               0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1)/n_batch + loss_prop/n_batch + loss_o + loss_disc + \
                   0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # single adam step on the joint objective
        params = get_params(state)
        return opt_update(i, grad(loss_snet3)(
            params, batch, penalty_l2, penalty_orthogonal, penalty_disc),
                          state)

    # initialise states
    _, init_params = init_fun_snet3(rng_key, input_shape)
    opt_state = opt_init(init_params)
    # fallback for the divergence path below; avoids UnboundLocalError if the
    # very first validation loss is already NaN
    params_best = init_params

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # slice end is exclusive, so cap at n (not n - 1) to avoid
            # silently dropping the last sample of the final batch
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size, n)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2,
                               penalty_orthogonal,
                               penalty_disc)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet3(params_curr, (X_val, y_val, w_val),
                                penalty_l2, penalty_orthogonal, penalty_disc)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i,
                                                         val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0,
                                         0, 0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        # (fixed: loss_snet3 takes all three penalty arguments, set to 0)
        l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val), 0,
                             0, 0)
        return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), \
               l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop)