예제 #1
0
def train_r_stage2(X,
                   y_ortho,
                   w_ortho,
                   n_layers_out: int = DEFAULT_LAYERS_OUT,
                   n_units_out: int = DEFAULT_UNITS_OUT,
                   n_layers_r: int = 0,
                   n_units_r: int = DEFAULT_UNITS_R,
                   penalty_l2: float = DEFAULT_PENALTY_L2,
                   step_size: float = DEFAULT_STEP_SIZE,
                   n_iter: int = DEFAULT_N_ITER,
                   batch_size: int = DEFAULT_BATCH_SIZE,
                   val_split_prop: float = DEFAULT_VAL_SPLIT,
                   early_stopping: bool = True,
                   patience: int = DEFAULT_PATIENCE,
                   n_iter_min: int = DEFAULT_N_ITER_MIN,
                   verbose: int = 1,
                   n_iter_print: int = DEFAULT_N_ITER_PRINT,
                   seed: int = DEFAULT_SEED,
                   return_val_loss: bool = False,
                   nonlin: str = DEFAULT_NONLIN,
                   avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """Train a single output head on orthogonalised targets and treatments.

    The head is fitted with Adam on shuffled minibatches. The objective is
    ``(y_ortho - w_ortho * prediction) ** 2`` summed over the batch (or
    averaged when ``avg_objective`` is true) plus ``0.5 * penalty_l2`` times
    the sum of squared layer weights.

    Parameters
    ----------
    X : array
        Covariates; ``X.shape[1]`` is used as the network input dimension.
    y_ortho, w_ortho : array
        Orthogonalised outcome and treatment. Both are passed through
        ``check_shape_1d_data`` before use.
    n_layers_out, n_units_out, n_layers_r, n_units_r, nonlin
        Architecture settings forwarded to ``OutputHead``.
    penalty_l2 : float
        Weight of the L2 penalty on the layer weights.
    step_size : float
        Adam step size.
    n_iter : int
        Maximum number of passes over the training data.
    batch_size : int
        Minibatch size; capped at the number of training rows.
    val_split_prop : float
        Forwarded to ``make_val_split``.
    early_stopping : bool
        If true, stop once the monitored loss has failed to improve for more
        than ``patience`` consecutive checks.
    patience : int
        See ``early_stopping``.
    n_iter_min : int
        Early stopping is only evaluated once the number of update steps
        ``(epoch + 1) * n_batches`` exceeds this value.
    verbose, n_iter_print : int
        When ``verbose > 0`` the monitored loss is printed every
        ``n_iter_print`` epochs.
    seed : int
        Seeds the JAX PRNG key, the global NumPy RNG (side effect) and the
        validation split.
    return_val_loss : bool
        If true, additionally return the unpenalised loss on the split
        returned by ``make_val_split``.
    avg_objective : bool
        Average instead of sum the squared residuals.

    Returns
    -------
    tuple
        ``(trained_params, predict_fun)``, or
        ``(trained_params, predict_fun, l_final)`` when ``return_val_loss``
        is true.
    """
    # input check
    y_ortho = check_shape_1d_data(y_ortho)
    w_ortho = check_shape_1d_data(w_ortho)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none)
    X, y_ortho, w_ortho, X_val, y_val, w_val, val_string = make_val_split(
        X,
        y_ortho,
        w_ortho,
        val_split_prop=val_split_prop,
        seed=seed,
        stratify_w=False)
    n = X.shape[0]  # could be different from before due to split

    # get output head
    init_fun, predict_fun = OutputHead(n_layers_out=n_layers_out,
                                       n_units_out=n_units_out,
                                       n_layers_r=n_layers_r,
                                       n_units_r=n_units_r,
                                       nonlin=nonlin)

    # define loss and grad
    @jit
    def loss(params, batch, penalty):
        # squared loss on the orthogonalised residuals plus L2 penalty
        inputs, ortho_targets, ortho_treats = batch
        preds = predict_fun(params, inputs)
        # weight matrices sit at the even indices of the parameter list
        weightsq = sum([
            jnp.sum(params[i][0]**2)
            for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)
        ])
        sq_resid = (ortho_targets - ortho_treats * preds) ** 2
        if not avg_objective:
            return jnp.sum(sq_resid) + 0.5 * penalty * weightsq
        return jnp.average(sq_resid) + 0.5 * penalty * weightsq

    # set optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    # set update function
    @jit
    def update(i, state, batch, penalty):
        params = get_params(state)
        g_params = grad(loss)(params, batch, penalty)
        return opt_update(i, g_params, state)

    # initialise states
    _, init_params = init_fun(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # Slicing clamps at n on its own. The previous explicit clamp to
            # n - 1 silently dropped the last shuffled sample every epoch.
            idx_next = train_indices[(b * batch_size):((b + 1) * batch_size)]
            next_batch = (X[idx_next, :], y_ortho[idx_next, :],
                          w_ortho[idx_next, :])
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss(params_curr, (X_val, y_val, w_val), penalty_l2)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss: {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
            else:
                p_curr = p_curr + 1

            if p_curr > patience:
                # patience exhausted: fall through to the common return path
                break

    # get final parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss(trained_params, (X_val, y_val, w_val), 0)
        return trained_params, predict_fun, l_final

    return trained_params, predict_fun
예제 #2
0
def _train_tnet_jointly(X,
                        y,
                        w,
                        binary_y: bool = False,
                        n_layers_out: int = DEFAULT_LAYERS_OUT,
                        n_units_out: int = DEFAULT_UNITS_OUT,
                        n_layers_r: int = DEFAULT_LAYERS_R,
                        n_units_r: int = DEFAULT_UNITS_R,
                        penalty_l2: float = DEFAULT_PENALTY_L2,
                        step_size: float = DEFAULT_STEP_SIZE,
                        n_iter: int = DEFAULT_N_ITER,
                        batch_size: int = DEFAULT_BATCH_SIZE,
                        val_split_prop: float = DEFAULT_VAL_SPLIT,
                        early_stopping: bool = True,
                        patience: int = DEFAULT_PATIENCE,
                        n_iter_min: int = DEFAULT_N_ITER_MIN,
                        verbose: int = 1,
                        n_iter_print: int = DEFAULT_N_ITER_PRINT,
                        seed: int = DEFAULT_SEED,
                        return_val_loss: bool = False,
                        same_init: bool = True,
                        penalty_diff: float = DEFAULT_PENALTY_L2,
                        nonlin: str = DEFAULT_NONLIN,
                        avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """Jointly train two output heads that share one architecture.

    Head 0 is fitted with sample weights ``1 - w`` and head 1 with sample
    weights ``w``. Each head uses a weighted squared error, or a weighted
    log loss when ``binary_y`` is true. The two head losses are added to
    ``0.5 *`` the value returned by ``heads_l2_penalty`` and minimised with
    Adam on shuffled minibatches.

    Parameters
    ----------
    X : array
        Covariates; ``X.shape[1]`` is used as the network input dimension.
    y, w : array
        Outcomes and treatment indicators; both are passed through
        ``check_shape_1d_data`` before use.
    binary_y : bool
        Use log loss instead of squared error for the heads.
    n_layers_out, n_units_out, n_layers_r, n_units_r, nonlin
        Architecture settings forwarded to ``OutputHead``.
    penalty_l2, penalty_diff : float
        Forwarded to ``heads_l2_penalty``.
    step_size : float
        Adam step size.
    n_iter : int
        Maximum number of passes over the training data.
    batch_size : int
        Minibatch size; capped at the number of training rows.
    val_split_prop : float
        Forwarded to ``make_val_split``.
    early_stopping, patience, n_iter_min
        Early stopping is evaluated once ``(epoch + 1) * n_batches`` exceeds
        ``n_iter_min``. Training stops when the monitored loss has not
        improved for more than ``patience`` consecutive checks, or as soon
        as it becomes NaN.
    verbose, n_iter_print : int
        When ``verbose > 0`` the monitored loss is printed every
        ``n_iter_print`` epochs.
    seed : int
        Seeds the JAX PRNG key, the global NumPy RNG (side effect) and the
        validation split.
    return_val_loss : bool
        If true, additionally return the unpenalised loss on the split
        returned by ``make_val_split``.
    same_init : bool
        If true both heads start from the same initial parameters; otherwise
        each head is initialised from its own PRNG key.
    avg_objective : bool
        Divide the head losses by the batch size.

    Returns
    -------
    tuple
        ``(params, predict_fun_head)`` where ``params`` is
        ``[head_0_params, head_1_params]``, or
        ``(params, predict_fun_head, l_final)`` when ``return_val_loss`` is
        true. If the monitored loss diverges to NaN the best parameters seen
        so far are returned (the initial ones if none were recorded).
    """
    # input check
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)

    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get output head functions (both heads share same structure)
    init_fun_head, predict_fun_head = OutputHead(n_layers_out=n_layers_out,
                                                 n_units_out=n_units_out,
                                                 binary_y=binary_y,
                                                 n_layers_r=n_layers_r,
                                                 n_units_r=n_units_r,
                                                 nonlin=nonlin)

    # Define loss functions
    # loss functions for the head
    if not binary_y:

        def loss_head(params, batch):
            # weighted mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return jnp.sum(weights * ((preds - targets)**2))
    else:

        def loss_head(params, batch):
            # weighted log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    @jit
    def loss_tnet(params, batch, penalty_l2, penalty_diff):
        # params: list[head_0, head_1]
        # batch: (X, y, w)
        X, y, w = batch

        # pass down to two heads
        loss_0 = loss_head(params[0], (X, y, 1 - w))
        loss_1 = loss_head(params[1], (X, y, w))

        # regularization
        weightsq_head = heads_l2_penalty(params[0], params[1],
                                         n_layers_r + n_layers_out, True,
                                         penalty_l2, penalty_diff)
        if not avg_objective:
            return loss_0 + loss_1 + 0.5 * (weightsq_head)
        n_batch = y.shape[0]
        return (loss_0 + loss_1) / n_batch + 0.5 * (weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # updating function
        params = get_params(state)
        return opt_update(
            i,
            grad(loss_tnet)(params, batch, penalty_l2, penalty_diff), state)

    def _result(params):
        # Build the return value; every exit path honours return_val_loss.
        if return_val_loss:
            # return loss without penalty
            l_final = loss_tnet(params, (X_val, y_val, w_val), 0, 0)
            return params, predict_fun_head, l_final
        return params, predict_fun_head

    # initialise states
    if same_init:
        _, init_head = init_fun_head(rng_key, input_shape)
        init_params = [init_head, init_head]
    else:
        rng_key, rng_key_2 = random.split(rng_key)
        _, init_head_0 = init_fun_head(rng_key, input_shape)
        _, init_head_1 = init_fun_head(rng_key_2, input_shape)
        init_params = [init_head_0, init_head_1]

    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0
    # Fallback for the divergence exit: if the loss is NaN before any
    # improvement was recorded there is no "best" yet, so use the initial
    # parameters instead of hitting an unbound local.
    params_best = get_params(opt_state)

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # Slicing clamps at n on its own. The previous explicit clamp to
            # n - 1 silently dropped the last shuffled sample every epoch.
            idx_next = train_indices[(b * batch_size):((b + 1) * batch_size)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_tnet(params_curr, (X_val, y_val, w_val), penalty_l2,
                               penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return _result(params_best)
                p_curr = p_curr + 1

            if p_curr > patience:
                return _result(params_curr)

    # return the parameters
    return _result(get_params(opt_state))
예제 #3
0
def train_r_net(X,
                y,
                w,
                p=None,
                second_stage_strategy: str = R_STRATEGY_NAME,
                data_split: bool = False,
                cross_fit: bool = False,
                n_cf_folds: int = DEFAULT_CF_FOLDS,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                n_units_out: int = DEFAULT_UNITS_OUT,
                n_units_r: int = DEFAULT_UNITS_R,
                n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                n_units_r_t: int = DEFAULT_UNITS_R_T,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                penalty_l2_t: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                step_size_t: float = DEFAULT_STEP_SIZE_T,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                nonlin: str = DEFAULT_NONLIN):
    """Train an R-learner style two-stage network.

    Stage one calls ``_train_and_predict_r_stage1`` to obtain ``mu_hat`` and
    ``pi_hat``, using one of three schemes:

    * all data for fitting and prediction (default),
    * ``data_split=True``: fit on a random half of the rows and keep only the
      other half for stage two,
    * ``cross_fit=True``: out-of-fold predictions from a stratified
      ``n_cf_folds``-fold split on ``w``.

    Stage two is trained on the residuals ``y - mu_hat`` and ``w - pi_hat``.

    Parameters
    ----------
    X, y, w : array
        Covariates, outcomes and treatment indicators.
    p : array, optional
        Known propensity scores; when given they replace ``pi_hat`` in
        stage two.
    second_stage_strategy : str
        ``R_STRATEGY_NAME`` trains via ``train_r_stage2``;
        ``U_STRATEGY_NAME`` trains ``train_output_net_only`` on
        ``y_ortho / w_ortho``.
    data_split, cross_fit, n_cf_folds
        Select the first-stage scheme described above. ``cross_fit`` takes
        precedence over ``data_split``.
    n_layers_out, n_layers_r, n_units_out, n_units_r, penalty_l2, step_size
        First-stage settings, forwarded to ``_train_and_predict_r_stage1``.
    n_layers_out_t, n_layers_r_t, n_units_out_t, n_units_r_t, penalty_l2_t, step_size_t
        Second-stage settings.
    n_iter, batch_size, val_split_prop, early_stopping, patience, n_iter_min, verbose, n_iter_print, seed, nonlin
        Forwarded unchanged to both stages.
    return_val_loss : bool
        Forwarded to the second-stage trainer.

    Returns
    -------
    Whatever the selected second-stage trainer returns.

    Raises
    ------
    ValueError
        If ``second_stage_strategy`` is neither ``R_STRATEGY_NAME`` nor
        ``U_STRATEGY_NAME`` (raised after stage one has run).
    """
    # get shape of data
    n, d = X.shape

    if p is not None:
        p = check_shape_1d_data(p)

    # settings shared by every first-stage call
    stage1_kwargs = dict(n_layers_out=n_layers_out,
                         n_layers_r=n_layers_r,
                         n_units_out=n_units_out,
                         n_units_r=n_units_r,
                         penalty_l2=penalty_l2,
                         step_size=step_size,
                         n_iter=n_iter,
                         batch_size=batch_size,
                         val_split_prop=val_split_prop,
                         early_stopping=early_stopping,
                         patience=patience,
                         n_iter_min=n_iter_min,
                         verbose=verbose,
                         n_iter_print=n_iter_print,
                         seed=seed,
                         nonlin=nonlin)

    # split data as wanted
    if not cross_fit:
        if not data_split:
            if verbose > 0:
                print('Training first stage with all data (no data splitting)')
            # use all data for both
            fit_mask = onp.ones(n, dtype=bool)
            pred_mask = onp.ones(n, dtype=bool)
        else:
            if verbose > 0:
                print(
                    'Training first stage with half of the data (data splitting)'
                )
            # Split data in half. Sampling must be without replacement,
            # otherwise duplicate draws leave far fewer than n / 2 distinct
            # rows in the fitting set.
            # NOTE: this draws from the global NumPy RNG; `seed` is not
            # applied to it in this function.
            fit_idx = onp.random.choice(n,
                                        int(onp.round(n / 2)),
                                        replace=False)
            fit_mask = onp.zeros(n, dtype=bool)

            fit_mask[fit_idx] = 1
            pred_mask = ~fit_mask

        mu_hat, pi_hat = _train_and_predict_r_stage1(X, y, w, fit_mask,
                                                     pred_mask,
                                                     **stage1_kwargs)
        if data_split:
            # Keep only prediction data. Masking the first axis alone works
            # for 1-d as well as 2-d y / w, which are only reshaped further
            # below.
            X, y, w = X[pred_mask, :], y[pred_mask], w[pred_mask]

            if p is not None:
                p = p[pred_mask]

    else:
        if verbose > 0:
            print('Training first stage in {} folds (cross-fitting)'.format(
                n_cf_folds))
        # do cross fitting
        mu_hat, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1))
        splitter = StratifiedKFold(n_splits=n_cf_folds,
                                   shuffle=True,
                                   random_state=seed)

        fold_count = 1
        for train_idx, test_idx in splitter.split(X, w):

            if verbose > 0:
                print('Training fold {}.'.format(fold_count))
            fold_count = fold_count + 1

            pred_mask = onp.zeros(n, dtype=bool)
            pred_mask[test_idx] = 1
            fit_mask = ~pred_mask

            mu_hat[pred_mask], pi_hat[pred_mask] = \
                _train_and_predict_r_stage1(X, y, w, fit_mask, pred_mask,
                                            **stage1_kwargs)

    if verbose > 0:
        print('Training second stage.')

    if p is not None:
        # use known propensity score
        p = check_shape_1d_data(p)
        pi_hat = p

    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    w_ortho = w - pi_hat
    y_ortho = y - mu_hat

    # settings shared by both second-stage strategies
    stage2_kwargs = dict(n_layers_out=n_layers_out_t,
                         n_units_out=n_units_out_t,
                         n_layers_r=n_layers_r_t,
                         n_units_r=n_units_r_t,
                         penalty_l2=penalty_l2_t,
                         step_size=step_size_t,
                         n_iter=n_iter,
                         batch_size=batch_size,
                         val_split_prop=val_split_prop,
                         early_stopping=early_stopping,
                         patience=patience,
                         n_iter_min=n_iter_min,
                         verbose=verbose,
                         n_iter_print=n_iter_print,
                         seed=seed,
                         return_val_loss=return_val_loss,
                         nonlin=nonlin)

    if second_stage_strategy == R_STRATEGY_NAME:
        return train_r_stage2(X, y_ortho, w_ortho, **stage2_kwargs)
    elif second_stage_strategy == U_STRATEGY_NAME:
        return train_output_net_only(X, y_ortho / w_ortho, **stage2_kwargs)
    else:
        raise ValueError('R-learner only supports strategies R and U.')
예제 #4
0
def train_snet2(X,
                y,
                w,
                binary_y: bool = False,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                reg_diff: bool = False,
                penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term].

    A shared representation block feeds three heads: two potential-outcome
    heads (head 0 weighted by ``1 - w``, head 1 weighted by ``w``) and a
    propensity head trained with log loss on ``w``. The objective is the sum
    of the three head losses plus L2 penalties on the representation, the
    propensity head and (via ``heads_l2_penalty``) the outcome heads. It is
    minimised with Adam on shuffled minibatches.

    Parameters
    ----------
    X : array
        Covariates; ``X.shape[1]`` is used as the network input dimension.
    y, w : array
        Outcomes and treatment indicators; both are passed through
        ``check_shape_1d_data`` before use.
    binary_y : bool
        Use log loss instead of squared error for the outcome heads.
    n_layers_r, n_units_r, nonlin
        Settings forwarded to ``ReprBlock``.
    n_layers_out, n_units_out
        Settings forwarded to ``OutputHead`` for all three heads.
    penalty_l2 : float
        L2 penalty weight.
    reg_diff, penalty_diff
        Forwarded to ``heads_l2_penalty``. When ``reg_diff`` is false,
        ``penalty_diff`` is overwritten with ``penalty_l2``.
    step_size : float
        Adam step size.
    n_iter : int
        Maximum number of passes over the training data.
    batch_size : int
        Minibatch size; capped at the number of training rows.
    val_split_prop : float
        Forwarded to ``make_val_split``.
    early_stopping, patience, n_iter_min
        Early stopping is evaluated once ``(epoch + 1) * n_batches`` exceeds
        ``n_iter_min``. Training stops when the monitored loss has not
        improved for more than ``patience`` consecutive checks, or as soon
        as it becomes NaN.
    verbose, n_iter_print : int
        When ``verbose > 0`` the monitored loss is printed every
        ``n_iter_print`` epochs.
    seed : int
        Seeds the JAX PRNG key, the global NumPy RNG (side effect) and the
        validation split.
    return_val_loss : bool
        If true, additionally return the unpenalised loss on the split
        returned by ``make_val_split``.
    avg_objective : bool
        Divide the head losses by the batch size.
    same_init : bool
        Initialise both outcome heads from the same PRNG key.

    Returns
    -------
    tuple
        ``(params, predict_funs)`` where ``params`` is
        ``[repr, head_0, head_1, head_prop]`` and ``predict_funs`` is
        ``(predict_fun_repr, predict_fun_head_po, predict_fun_head_prop)``;
        a third element ``l_final`` is appended when ``return_val_loss`` is
        true. If the monitored loss diverges to NaN the best parameters seen
        so far are returned (the initial ones if none were recorded).
    """
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layer
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r,
                                                n_units=n_units_r,
                                                nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=binary_y,
        nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(
        n_layers_out=n_layers_out,
        n_units_out=n_units_out,
        binary_y=True,
        nonlin=nonlin)

    predict_funs = (predict_fun_repr, predict_fun_head_po,
                    predict_fun_head_prop)

    def init_fun_snet2(rng, input_shape):
        # chain together the layers
        # param should look like [repr, po_0, po_1, prop]
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)

        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng,
                                                    input_shape_repr)

        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng,
                                                     input_shape_repr)
        return input_shape, [param_repr, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:

        def loss_head(params, batch):
            # weighted mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets)**2))
    else:

        def loss_head(params, batch):
            # weighted log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    def loss_head_prop(params, batch):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)

        return -jnp.sum(targets * jnp.log(preds) +
                        (1 - targets) * jnp.log(1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet2(params, batch, penalty_l2, penalty_diff):
        # params: list[representation, head_0, head_1, head_prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps = predict_fun_repr(params[0], X)

        # pass down to heads
        loss_0 = loss_head(params[1], (reps, y, 1 - w))
        loss_1 = loss_head(params[2], (reps, y, w))

        # pass down to propensity head
        loss_prop = loss_head_prop(params[3], (reps, w))
        # weight matrices sit at the even indices of each parameter list
        weightsq_prop = sum([
            jnp.sum(params[3][i][0]**2)
            for i in range(0, 2 * n_layers_out + 1, 2)
        ])

        weightsq_body = sum(
            [jnp.sum(params[0][i][0]**2) for i in range(0, 2 * n_layers_r, 2)])
        weightsq_head = heads_l2_penalty(params[1], params[2], n_layers_out,
                                         reg_diff, penalty_l2, penalty_diff)

        reg_term = 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) +
                          weightsq_head)
        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + reg_term
        n_batch = y.shape[0]
        return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + reg_term

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # updating function
        params = get_params(state)
        return opt_update(
            i,
            grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state)

    def _result(params):
        # Build the return value; every exit path honours return_val_loss.
        if return_val_loss:
            # return loss without penalty
            l_final = loss_snet2(params, (X_val, y_val, w_val), 0, 0)
            return params, predict_funs, l_final
        return params, predict_funs

    # initialise states
    _, init_params = init_fun_snet2(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0
    # Fallback for the divergence exit: if the loss is NaN before any
    # improvement was recorded there is no "best" yet, so use the initial
    # parameters instead of hitting an unbound local.
    params_best = get_params(opt_state)

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # Slicing clamps at n on its own. The previous explicit clamp to
            # n - 1 silently dropped the last shuffled sample every epoch.
            idx_next = train_indices[(b * batch_size):((b + 1) * batch_size)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet2(params_curr, (X_val, y_val, w_val), penalty_l2,
                                penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(
                i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return _result(params_best)
                p_curr = p_curr + 1

            if p_curr > patience:
                return _result(params_curr)

    # return the parameters
    return _result(get_params(opt_state))
예제 #5
0
def train_snet3(X, y, w, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
                n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: float = DEFAULT_PENALTY_DISC,
                penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True, n_iter_min: int = DEFAULT_N_ITER_MIN,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1, n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED, return_val_loss: bool = False,
                reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet-3, based on the decomposition used in Hassanpour and Greiner (2020).

    Jointly trains three representation blocks (one with ``n_units_r`` units and
    two with ``n_units_r_small`` units), two potential-outcome heads and one
    propensity head. The outcome heads receive the first and second
    representation concatenated; the propensity head receives the first and
    third representation concatenated.

    Parameters
    ----------
    X, y, w :
        Covariates, outcomes and treatment indicator. ``y`` and ``w`` are passed
        through ``check_shape_1d_data``. ``w`` / ``1 - w`` weight the losses of
        the two outcome heads and ``w`` is the target of the propensity head.
    binary_y :
        If True the outcome heads use a log loss, otherwise a squared-error loss.
    n_layers_r, n_units_r, n_units_r_small :
        Depth and widths forwarded to ``ReprBlock``.
    n_layers_out, n_units_out :
        Depth and width forwarded to ``OutputHead`` (outcome and propensity heads).
    penalty_l2 :
        Weight-decay factor applied to the representation and propensity weights
        and forwarded to ``heads_l2_penalty``.
    penalty_disc :
        Factor on ``mmd2_lin`` computed on the second representation and ``w``.
    penalty_orthogonal :
        Factor on the penalty built from pairwise products of the absolute row
        sums of ``params[j][0][0]`` for the three representation blocks.
    step_size :
        Step size of the Adam optimiser.
    n_iter, batch_size :
        Number of epochs (outer loop passes) and minibatch size.
    val_split_prop, seed :
        Forwarded to ``make_val_split``; ``seed`` also seeds the JAX PRNG key and
        NumPy's global random state.
    early_stopping, n_iter_min, patience :
        With early stopping, the validation loss is checked after each epoch once
        ``(epoch + 1) * n_batches > n_iter_min``; training stops when the loss has
        failed to improve on more than ``patience`` consecutive checks.
    verbose, n_iter_print :
        When ``verbose > 0`` the validation loss is printed every ``n_iter_print``
        epochs.
    return_val_loss :
        If True, additionally return the validation loss evaluated with all
        penalties set to zero.
    reg_diff, penalty_diff :
        Forwarded to ``heads_l2_penalty``; when ``reg_diff`` is False,
        ``penalty_diff`` is replaced by ``penalty_l2``.
    nonlin :
        Non-linearity name forwarded to ``ReprBlock`` and ``OutputHead``.
    avg_objective :
        If True the outcome and propensity losses are divided by the batch size.
    same_init :
        If True both outcome heads are initialised from the same RNG key.

    Returns
    -------
    tuple
        ``(params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop))``
        and, when ``return_val_loss`` is True, the penalty-free validation loss as
        a third element. ``params`` is the list
        ``[repr_c, repr_o, repr_w, po_0, po_1, prop]``.
    """
    # function to train a net with 3 representations
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(X, y, w,
                                                              val_split_prop=val_split_prop,
                                                              seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layers
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r, n_units=n_units_r,
                                                nonlin=nonlin)
    init_fun_repr_small, predict_fun_repr_small = ReprBlock(n_layers=n_layers_r,
                                                            n_units=n_units_r_small, nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(n_layers_out=n_layers_out,
                                                       n_units_out=n_units_out,
                                                       binary_y=binary_y, nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(n_layers_out=n_layers_out,
                                                           n_units_out=n_units_out, binary_y=True,
                                                           nonlin=nonlin)

    # every return path hands back the same prediction functions
    predict_funs = (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop)

    def init_fun_snet3(rng, input_shape):
        # chain together the layers
        # param should look like [repr_c, repr_o, repr_w, po_0, po_1, prop]
        # initialise representation layers
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        input_shape_repr_small, param_repr_o = init_fun_repr_small(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        _, param_repr_w = init_fun_repr_small(layer_rng, input_shape)

        # each head gets two representations
        input_shape_repr = input_shape_repr[:-1] + (input_shape_repr[-1] + input_shape_repr_small[
            -1],)

        # initialise output heads
        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
        return input_shape, [param_repr_c, param_repr_o, param_repr_w, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:
        def loss_head(params, batch, penalty):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch, penalty):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return - jnp.sum(weights * (targets * jnp.log(preds) +
                                        (1 - targets) * jnp.log(
                        1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)
        return - jnp.sum(targets * jnp.log(preds) +
                         (1 - targets) * jnp.log(
            1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet3(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # params: list[repr_c, repr_o, repr_w, po_0, po_1, prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps_c = predict_fun_repr(params[0], X)
        reps_o = predict_fun_repr_small(params[1], X)
        reps_w = predict_fun_repr_small(params[2], X)

        # concatenate
        reps_po = _concatenate_representations((reps_c, reps_o))
        reps_prop = _concatenate_representations((reps_c, reps_w))

        # pass down to heads
        loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)

        # pass down to propensity head
        loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)
        weightsq_prop = sum([jnp.sum(params[5][i][0] ** 2) for i in
                             range(0, 2 * n_layers_out + 1, 2)])

        # which variable has impact on which representation
        col_c = _get_absolute_rowsums(params[0][0][0])
        col_o = _get_absolute_rowsums(params[1][0][0])
        col_w = _get_absolute_rowsums(params[2][0][0])
        loss_o = penalty_orthogonal * (
            jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o))

        # is rep_o balanced between groups?
        loss_disc = penalty_disc * mmd2_lin(reps_o, w)

        # weight decay on representations
        weightsq_body = sum([sum([jnp.sum(params[j][i][0] ** 2) for i in
                                  range(0, 2 * n_layers_r, 2)]) for j in range(3)])
        weightsq_head = heads_l2_penalty(params[3], params[4], n_layers_out, reg_diff,
                                         penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + loss_o + loss_disc + \
               0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1)/n_batch + loss_prop/n_batch + loss_o + loss_disc + \
                   0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # updating function
        params = get_params(state)
        return opt_update(i, grad(loss_snet3)(
            params, batch, penalty_l2, penalty_orthogonal, penalty_disc),
                          state)

    # initialise states
    _, init_params = init_fun_snet3(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0
    # start from the initial parameters so the divergence path below always has
    # something to return, even if the very first check already yields NaN
    params_best = init_params

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)
        for b in range(n_batches):
            # slice ends are exclusive, so cap at n (capping at n - 1 would
            # silently drop the last shuffled sample every epoch)
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size, n)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2,
                               penalty_orthogonal,
                               penalty_disc)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet3(params_curr, (X_val, y_val, w_val),
                                penalty_l2, penalty_orthogonal, penalty_disc)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i,
                                                         val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    # NOTE(review): this path returns a 2-tuple even when
                    # return_val_loss is True
                    return params_best, predict_funs
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0,
                                         0, 0)
                    return params_curr, predict_funs, l_final

                return params_curr, predict_funs

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty; loss_snet3 takes three penalty arguments
        # (l2, orthogonal, disc), all of which are zeroed here
        l_final = loss_snet3(trained_params, (X_val, y_val, w_val), 0, 0, 0)
        return trained_params, predict_funs, l_final

    return trained_params, predict_funs
예제 #6
0
def train_twostep_net(X, y, w, p=None, first_stage_strategy: str = T_STRATEGY,
                      data_split: bool = False,
                      cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS,
                      transformation: str = AIPW_TRANSFORMATION,
                      binary_y: bool = False,
                      n_layers_out: int = DEFAULT_LAYERS_OUT,
                      n_layers_r: int = DEFAULT_LAYERS_R,
                      n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                      n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                      n_units_out: int = DEFAULT_UNITS_OUT,
                      n_units_r: int = DEFAULT_UNITS_R,
                      n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                      n_units_r_t: int = DEFAULT_UNITS_R_T,
                      penalty_l2: float = DEFAULT_PENALTY_L2,
                      penalty_l2_t: float = DEFAULT_PENALTY_L2,
                      step_size: float = DEFAULT_STEP_SIZE,
                      step_size_t: float = DEFAULT_STEP_SIZE_T,
                      n_iter: int = DEFAULT_N_ITER,
                      batch_size: int = DEFAULT_BATCH_SIZE,
                      val_split_prop: float = DEFAULT_VAL_SPLIT,
                      early_stopping: bool = True,
                      patience: int = DEFAULT_PATIENCE,
                      n_iter_min: int = DEFAULT_N_ITER_MIN,
                      verbose: int = 1, n_iter_print: int = DEFAULT_N_ITER_PRINT,
                      seed: int = DEFAULT_SEED, rescale_transformation: bool = False,
                      return_val_loss: bool = False,
                      penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                      n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S,
                      nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """
    Train a two-step learner: first-stage nuisance estimates, then a network
    fitted on a transformed (pseudo) outcome.

    First stage (skipped only when ``p`` is given and ``transformation`` equals
    ``HT_TRANSFORMATION``): ``_train_and_predict_first_stage`` produces
    ``mu_0``, ``mu_1`` and ``pi_hat`` using one of three schemes:

    * all data for fitting and predicting (default),
    * ``data_split=True``: fit on a random half, predict on and keep only the
      other half for the second stage,
    * ``cross_fit=True``: ``n_cf_folds``-fold stratified cross-fitting.

    Second stage: the pseudo outcome returned by the transformation function is
    regressed on ``X`` with ``train_output_net_only`` using the ``*_t``
    hyper-parameters.

    Parameters
    ----------
    X, y, w :
        Covariates, outcomes and treatment indicator.
    p :
        Optional known propensity scores; when given they replace ``pi_hat``.
    first_stage_strategy :
        Must be a member of ``ALL_STRATEGIES``.
    transformation :
        Name resolved by ``_get_transformation_function``.
    rescale_transformation :
        If True, the pseudo outcome is scaled by ``std(y) / std(pseudo_outcome)``
        whenever that ratio is at most 1, and the scale factor is returned too.
    return_val_loss :
        Forwarded to the second-stage ``train_output_net_only`` call.
    Remaining keyword arguments are forwarded unchanged to the first-stage
    (unsuffixed) or second-stage (``*_t``) training routines.

    Returns
    -------
    The return value of ``train_output_net_only``; when
    ``rescale_transformation`` is True, ``(params, predict_funs, scale_factor)``.

    Raises
    ------
    ValueError
        If ``first_stage_strategy`` is not in ``ALL_STRATEGIES``.
    """
    # get shape of data
    n, d = X.shape

    if p is not None:
        p = check_shape_1d_data(p)

    # get transformation function
    transformation_function = _get_transformation_function(transformation)

    # get strategy name
    if first_stage_strategy not in ALL_STRATEGIES:
        raise ValueError('Parameter first stage should be in '
                         'catenets.models.twostep_nets.ALL_STRATEGIES. '
                         'You passed {}'.format(first_stage_strategy))

    # keyword arguments shared by every first-stage call
    first_stage_kwargs = dict(first_stage_strategy=first_stage_strategy,
                              binary_y=binary_y,
                              n_layers_out=n_layers_out,
                              n_layers_r=n_layers_r,
                              n_units_out=n_units_out,
                              n_units_r=n_units_r,
                              penalty_l2=penalty_l2,
                              step_size=step_size,
                              n_iter=n_iter,
                              batch_size=batch_size,
                              val_split_prop=val_split_prop,
                              early_stopping=early_stopping,
                              patience=patience,
                              n_iter_min=n_iter_min,
                              verbose=verbose,
                              n_iter_print=n_iter_print,
                              seed=seed,
                              penalty_orthogonal=penalty_orthogonal,
                              n_units_r_small=n_units_r_small,
                              nonlin=nonlin,
                              avg_objective=avg_objective,
                              transformation=transformation)

    # split data as wanted
    # compare transformation names by value: identity (`is`) only works when the
    # caller happens to pass the very same string object
    if p is None or transformation != HT_TRANSFORMATION:
        if not cross_fit:
            if not data_split:
                if verbose > 0:
                    print('Training first stage with all data (no data splitting)')
                # use all data for both
                fit_mask = onp.ones(n, dtype=bool)
                pred_mask = onp.ones(n, dtype=bool)
            else:
                if verbose > 0:
                    print('Training first stage with half of the data (data splitting)')
                # split data in half; sample without replacement so that the
                # fit set really contains round(n / 2) distinct rows
                fit_idx = onp.random.choice(n, int(onp.round(n / 2)), replace=False)
                fit_mask = onp.zeros(n, dtype=bool)

                fit_mask[fit_idx] = True
                pred_mask = ~ fit_mask

            mu_0, mu_1, pi_hat = _train_and_predict_first_stage(X, y, w, fit_mask, pred_mask,
                                                                **first_stage_kwargs)
            if data_split:
                # keep only prediction data; plain row masking works whether
                # y / w / p are still 1-d or already column vectors
                X, y, w = X[pred_mask, :], y[pred_mask], w[pred_mask]

                if p is not None:
                    p = p[pred_mask]

        else:
            if verbose > 0:
                print('Training first stage in {} folds (cross-fitting)'.format(n_cf_folds))
            # do cross fitting
            mu_0, mu_1, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)), onp.zeros((n, 1))
            splitter = StratifiedKFold(n_splits=n_cf_folds, shuffle=True,
                                       random_state=seed)

            for fold_count, (train_idx, test_idx) in enumerate(splitter.split(X, w), start=1):

                if verbose > 0:
                    print('Training fold {}.'.format(fold_count))

                pred_mask = onp.zeros(n, dtype=bool)
                pred_mask[test_idx] = True
                fit_mask = ~ pred_mask

                # fit on the other folds, fill in predictions for this fold
                mu_0[pred_mask], mu_1[pred_mask], pi_hat[pred_mask] = \
                    _train_and_predict_first_stage(X, y, w, fit_mask, pred_mask,
                                                   **first_stage_kwargs)

    if verbose > 0:
        print('Training second stage.')

    if p is not None:
        # use known propensity score
        p = check_shape_1d_data(p)
        pi_hat = p

    # second stage
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    # transform data and fit on transformed data
    if transformation == HT_TRANSFORMATION:
        mu_0 = None
        mu_1 = None

    pseudo_outcome = transformation_function(y=y, w=w, p=pi_hat, mu_0=mu_0, mu_1=mu_1)

    # keyword arguments shared by both second-stage call sites
    second_stage_kwargs = dict(binary_y=False,
                               n_layers_out=n_layers_out_t,
                               n_units_out=n_units_out_t,
                               n_layers_r=n_layers_r_t,
                               n_units_r=n_units_r_t,
                               penalty_l2=penalty_l2_t,
                               step_size=step_size_t,
                               n_iter=n_iter,
                               batch_size=batch_size,
                               val_split_prop=val_split_prop,
                               early_stopping=early_stopping,
                               patience=patience,
                               n_iter_min=n_iter_min,
                               n_iter_print=n_iter_print,
                               verbose=verbose,
                               seed=seed,
                               return_val_loss=return_val_loss,
                               nonlin=nonlin,
                               avg_objective=avg_objective)

    if rescale_transformation:
        scale_factor = onp.std(y) / onp.std(pseudo_outcome)
        if scale_factor > 1:
            # never blow the pseudo outcome up, only shrink it
            scale_factor = 1
        else:
            pseudo_outcome = scale_factor * pseudo_outcome
        # NOTE(review): this unpacking assumes train_output_net_only returns
        # exactly two values; confirm that still holds when return_val_loss=True
        params, predict_funs = train_output_net_only(X, pseudo_outcome,
                                                     **second_stage_kwargs)
        return params, predict_funs, scale_factor
    else:
        return train_output_net_only(X, pseudo_outcome, **second_stage_kwargs)
예제 #7
0
def train_x_net(X,
                y,
                w,
                weight_strategy: int = None,
                binary_y: bool = False,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                n_units_out: int = DEFAULT_UNITS_OUT,
                n_units_r: int = DEFAULT_UNITS_R,
                n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                n_units_r_t: int = DEFAULT_UNITS_R_T,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                penalty_l2_t: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                step_size_t: float = DEFAULT_STEP_SIZE_T,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                nonlin: str = DEFAULT_NONLIN,
                return_val_loss: bool = False,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    """
    Train the networks of an X-learner style estimator.

    First stage: outcome networks are fitted separately on the ``w == 0`` and
    ``w == 1`` samples and used to predict on the opposite group; optionally a
    propensity network is fitted on all data. Second stage: networks are fitted
    on the imputed differences ``mu_hat_1 - y`` (group ``w == 0``) and
    ``y - mu_hat_0`` (group ``w == 1``) using the ``*_t`` hyper-parameters.

    ``weight_strategy`` is coded as follows, for
    tau(x) = g(x) tau_0(x) + (1 - g(x)) tau_1(x) [eq 9, kuenzel et al (2019)]:
    ``0`` sets g(x)=0, ``1`` sets g(x)=1, ``None`` sets g(x)=pi(x) [propensity
    score] and ``-1`` sets g(x)=(1-pi(x)). Components that a strategy does not
    need are not trained and are returned as ``None``.

    Returns
    -------
    tuple
        ``((params_tau0, params_tau1, params_prop),
        (predict_fun_tau0, predict_fun_tau1, predict_fun_prop))``

    Raises
    ------
    ValueError
        If ``weight_strategy`` is not one of ``0, 1, -1, None``.
    """
    y = check_shape_1d_data(y)
    # treatment indicator is needed flat so it can be used as a boolean row mask
    if len(w.shape) > 1:
        w = w.reshape((w.shape[0],))

    if weight_strategy not in (0, 1, -1, None):
        raise ValueError(
            'XNet only implements weight_strategy in [0, 1, -1, None]')

    # which second-stage components the chosen strategy requires
    wants_tau0 = weight_strategy != 0  # tau_0 relies on mu_hat_1
    wants_tau1 = weight_strategy != 1  # tau_1 relies on mu_hat_0
    wants_prop = weight_strategy is None or weight_strategy == -1

    in_control = w == 0
    in_treated = w == 1

    # settings every network shares
    shared_cfg = dict(n_iter=n_iter,
                      batch_size=batch_size,
                      val_split_prop=val_split_prop,
                      early_stopping=early_stopping,
                      patience=patience,
                      n_iter_min=n_iter_min,
                      n_iter_print=n_iter_print,
                      verbose=verbose,
                      seed=seed,
                      nonlin=nonlin,
                      avg_objective=avg_objective)
    # architecture / optimiser settings of the first-stage networks
    stage_one_cfg = dict(shared_cfg,
                         n_layers_out=n_layers_out,
                         n_units_out=n_units_out,
                         n_layers_r=n_layers_r,
                         n_units_r=n_units_r,
                         penalty_l2=penalty_l2,
                         step_size=step_size)
    # second-stage networks always regress a continuous pseudo outcome
    stage_two_cfg = dict(shared_cfg,
                         binary_y=False,
                         n_layers_out=n_layers_out_t,
                         n_units_out=n_units_out_t,
                         n_layers_r=n_layers_r_t,
                         n_units_r=n_units_r_t,
                         penalty_l2=penalty_l2_t,
                         step_size=step_size_t,
                         return_val_loss=return_val_loss)

    # ---- first stage: potential-outcome regressions ----
    if verbose > 0:
        print("Training first stage")

    mu_hat_0 = None
    if wants_tau1:
        if verbose > 0:
            print('Training PO_0 Net')
        params_0, predict_fun_0 = train_output_net_only(
            X[in_control], y[in_control], binary_y=binary_y, **stage_one_cfg)
        # control-outcome predictions for the treated units
        mu_hat_0 = predict_fun_0(params_0, X[in_treated])

    mu_hat_1 = None
    if wants_tau0:
        if verbose > 0:
            print('Training PO_1 Net')
        params_1, predict_fun_1 = train_output_net_only(
            X[in_treated], y[in_treated], binary_y=binary_y, **stage_one_cfg)
        # treated-outcome predictions for the control units
        mu_hat_1 = predict_fun_1(params_1, X[in_control])

    params_prop, predict_fun_prop = None, None
    if wants_prop:
        # also fit propensity estimator
        if verbose > 0:
            print('Training propensity net')
        params_prop, predict_fun_prop = train_output_net_only(
            X, w, binary_y=True, **stage_one_cfg)

    # ---- second stage: regress imputed effects ----
    if verbose > 0:
        print("Training second stage")

    params_tau0, predict_fun_tau0 = None, None
    if wants_tau0:
        if verbose > 0:
            print("Fitting tau_0")
        pseudo_outcome0 = mu_hat_1 - y[in_control]
        params_tau0, predict_fun_tau0 = train_output_net_only(
            X[in_control], pseudo_outcome0, **stage_two_cfg)

    params_tau1, predict_fun_tau1 = None, None
    if wants_tau1:
        if verbose > 0:
            print("Fitting tau_1")
        pseudo_outcome1 = y[in_treated] - mu_hat_0
        params_tau1, predict_fun_tau1 = train_output_net_only(
            X[in_treated], pseudo_outcome1, **stage_two_cfg)

    return ((params_tau0, params_tau1, params_prop),
            (predict_fun_tau0, predict_fun_tau1, predict_fun_prop))