def train_r_stage2(X, y_ortho, w_ortho,
                   n_layers_out: int = DEFAULT_LAYERS_OUT,
                   n_units_out: int = DEFAULT_UNITS_OUT,
                   n_layers_r: int = 0,
                   n_units_r: int = DEFAULT_UNITS_R,
                   penalty_l2: float = DEFAULT_PENALTY_L2,
                   step_size: float = DEFAULT_STEP_SIZE,
                   n_iter: int = DEFAULT_N_ITER,
                   batch_size: int = DEFAULT_BATCH_SIZE,
                   val_split_prop: float = DEFAULT_VAL_SPLIT,
                   early_stopping: bool = True,
                   patience: int = DEFAULT_PATIENCE,
                   n_iter_min: int = DEFAULT_N_ITER_MIN,
                   verbose: int = 1,
                   n_iter_print: int = DEFAULT_N_ITER_PRINT,
                   seed: int = DEFAULT_SEED,
                   return_val_loss: bool = False,
                   nonlin: str = DEFAULT_NONLIN,
                   avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    # function to train a single output head on the residualised data
    # (second stage of the R-learner)
    # input check
    y_ortho, w_ortho = check_shape_1d_data(y_ortho), check_shape_1d_data(w_ortho)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none)
    X, y_ortho, w_ortho, X_val, y_val, w_val, val_string = make_val_split(
        X, y_ortho, w_ortho, val_split_prop=val_split_prop, seed=seed,
        stratify_w=False)
    n = X.shape[0]  # could be different from before due to split

    # get output head
    init_fun, predict_fun = OutputHead(n_layers_out=n_layers_out,
                                       n_units_out=n_units_out,
                                       n_layers_r=n_layers_r,
                                       n_units_r=n_units_r,
                                       nonlin=nonlin)

    # define loss and grad
    @jit
    def loss(params, batch, penalty):
        # R-loss: squared error of the residualised outcome against the
        # residualised treatment times the CATE prediction, plus L2 penalty
        inputs, ortho_targets, ortho_treats = batch
        preds = predict_fun(params, inputs)
        weightsq = sum([jnp.sum(params[i][0] ** 2)
                        for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)])
        if not avg_objective:
            return jnp.sum((ortho_targets - ortho_treats * preds) ** 2) + \
                0.5 * penalty * weightsq
        else:
            return jnp.average((ortho_targets - ortho_treats * preds) ** 2) + \
                0.5 * penalty * weightsq

    # set optimization routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    # set update function
    @jit
    def update(i, state, batch, penalty):
        params = get_params(state)
        g_params = grad(loss)(params, batch, penalty)
        # g_params = optimizers.clip_grads(g_params, 1.0)
        return opt_update(i, g_params, state)

    # initialise states
    _, init_params = init_fun(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)

        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size,
                                                          n - 1)]
            next_batch = X[idx_next, :], y_ortho[idx_next, :], w_ortho[idx_next, :]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss(params_curr, (X_val, y_val, w_val), penalty_l2)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss: {}".format(i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
            else:
                p_curr = p_curr + 1

            if p_curr > patience:
                trained_params = get_params(opt_state)

                if return_val_loss:
                    # return loss without penalty
                    l_final = loss(trained_params, (X_val, y_val, w_val), 0)
                    return trained_params, predict_fun, l_final

                return trained_params, predict_fun

    # get final parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss(trained_params, (X_val, y_val, w_val), 0)
        return trained_params, predict_fun, l_final

    return trained_params, predict_fun
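
# Illustrative usage sketch (not part of the original source): train_r_stage2 fits
# the second-stage R-learner objective sum((y_ortho - w_ortho * tau(X))^2) on
# already-residualised outcome and treatment. The nuisance functions are taken as
# known here, and the synthetic data and hyperparameters are assumptions made only
# for this example.
def _example_r_stage2():
    import numpy as onp  # redundant with the module import, kept for self-containment

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    pi = 0.5 * onp.ones(n)                            # known propensity score
    w = rng.binomial(1, pi)
    tau = X[:, 0]                                     # true CATE
    mu = X[:, 1] + pi * tau                           # E[y | X] under this DGP
    y = X[:, 1] + w * tau + 0.1 * rng.normal(size=n)

    # residualise outcome and treatment against the (known) nuisances
    y_ortho, w_ortho = y - mu, w - pi
    params, predict_fun = train_r_stage2(X, y_ortho, w_ortho,
                                         n_iter=100, verbose=0)
    return predict_fun(params, X)                     # CATE estimates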
def _train_tnet_jointly(X, y, w, binary_y: bool = False,
                        n_layers_out: int = DEFAULT_LAYERS_OUT,
                        n_units_out: int = DEFAULT_UNITS_OUT,
                        n_layers_r: int = DEFAULT_LAYERS_R,
                        n_units_r: int = DEFAULT_UNITS_R,
                        penalty_l2: float = DEFAULT_PENALTY_L2,
                        step_size: float = DEFAULT_STEP_SIZE,
                        n_iter: int = DEFAULT_N_ITER,
                        batch_size: int = DEFAULT_BATCH_SIZE,
                        val_split_prop: float = DEFAULT_VAL_SPLIT,
                        early_stopping: bool = True,
                        patience: int = DEFAULT_PATIENCE,
                        n_iter_min: int = DEFAULT_N_ITER_MIN,
                        verbose: int = 1,
                        n_iter_print: int = DEFAULT_N_ITER_PRINT,
                        seed: int = DEFAULT_SEED,
                        return_val_loss: bool = False,
                        same_init: bool = True,
                        penalty_diff: float = DEFAULT_PENALTY_L2,
                        nonlin: str = DEFAULT_NONLIN,
                        avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    # input check
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get output head functions (both heads share same structure)
    init_fun_head, predict_fun_head = OutputHead(n_layers_out=n_layers_out,
                                                 n_units_out=n_units_out,
                                                 binary_y=binary_y,
                                                 n_layers_r=n_layers_r,
                                                 n_units_r=n_units_r,
                                                 nonlin=nonlin)

    # Define loss functions
    # loss functions for the head
    if not binary_y:
        def loss_head(params, batch):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch):
            # log loss function (binary outcome)
            inputs, targets, weights = batch
            preds = predict_fun_head(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    @jit
    def loss_tnet(params, batch, penalty_l2, penalty_diff):
        # params: list[head_0, head_1]
        # batch: (X, y, w)
        X, y, w = batch

        # pass down to two heads
        loss_0 = loss_head(params[0], (X, y, 1 - w))
        loss_1 = loss_head(params[1], (X, y, w))

        # regularization
        weightsq_head = heads_l2_penalty(params[0], params[1],
                                         n_layers_r + n_layers_out, True,
                                         penalty_l2, penalty_diff)
        if not avg_objective:
            return loss_0 + loss_1 + 0.5 * (weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + 0.5 * (weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # updating function
        params = get_params(state)
        return opt_update(
            i, grad(loss_tnet)(params, batch, penalty_l2, penalty_diff), state)

    # initialise states
    if same_init:
        _, init_head = init_fun_head(rng_key, input_shape)
        init_params = [init_head, init_head]
    else:
        rng_key, rng_key_2 = random.split(rng_key)
        _, init_head_0 = init_fun_head(rng_key, input_shape)
        _, init_head_1 = init_fun_head(rng_key_2, input_shape)
        init_params = [init_head_0, init_head_1]

    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)

        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size,
                                                          n - 1)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_tnet(params_curr, (X_val, y_val, w_val), penalty_l2,
                               penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, predict_fun_head
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_tnet(params_curr, (X_val, y_val, w_val), 0, 0)
                    return params_curr, predict_fun_head, l_final

                return params_curr, predict_fun_head

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_tnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
        return trained_params, predict_fun_head, l_final

    return trained_params, predict_fun_head
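
# Illustrative usage sketch (not part of the original source): fitting the jointly
# regularised T-learner with _train_tnet_jointly on synthetic data and forming CATE
# estimates from the two potential-outcome heads. The synthetic data-generating
# process and the reduced n_iter below are assumptions made for this example only.
def _example_tnet_joint():
    import numpy as onp

    rng = onp.random.RandomState(42)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)                  # randomised treatment
    tau = X[:, 0]                                     # heterogeneous effect
    y = X[:, 1] + w * tau + 0.1 * rng.normal(size=n)  # continuous outcome

    # params is [head_0, head_1]; predict_fun_head evaluates a single head
    params, predict_fun_head = _train_tnet_jointly(X, y, w, n_iter=100, verbose=0)
    cate_hat = predict_fun_head(params[1], X) - predict_fun_head(params[0], X)
    return cate_hat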
def train_r_net(X, y, w, p=None,
                second_stage_strategy: str = R_STRATEGY_NAME,
                data_split: bool = False,
                cross_fit: bool = False,
                n_cf_folds: int = DEFAULT_CF_FOLDS,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                n_units_out: int = DEFAULT_UNITS_OUT,
                n_units_r: int = DEFAULT_UNITS_R,
                n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                n_units_r_t: int = DEFAULT_UNITS_R_T,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                penalty_l2_t: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                step_size_t: float = DEFAULT_STEP_SIZE_T,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                nonlin: str = DEFAULT_NONLIN):
    # get shape of data
    n, d = X.shape

    if p is not None:
        p = check_shape_1d_data(p)

    # split data as wanted
    if not cross_fit:
        if not data_split:
            if verbose > 0:
                print('Training first stage with all data (no data splitting)')
            # use all data for both
            fit_mask = onp.ones(n, dtype=bool)
            pred_mask = onp.ones(n, dtype=bool)
        else:
            if verbose > 0:
                print('Training first stage with half of the data (data splitting)')
            # split data in half
            fit_idx = onp.random.choice(n, int(onp.round(n / 2)))
            fit_mask = onp.zeros(n, dtype=bool)
            fit_mask[fit_idx] = 1
            pred_mask = ~fit_mask

        mu_hat, pi_hat = _train_and_predict_r_stage1(
            X, y, w, fit_mask, pred_mask,
            n_layers_out=n_layers_out, n_layers_r=n_layers_r,
            n_units_out=n_units_out, n_units_r=n_units_r,
            penalty_l2=penalty_l2, step_size=step_size,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            verbose=verbose, n_iter_print=n_iter_print, seed=seed,
            nonlin=nonlin)

        if data_split:
            # keep only prediction data
            X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :]
            if p is not None:
                p = p[pred_mask, :]
    else:
        if verbose > 0:
            print('Training first stage in {} folds (cross-fitting)'.format(
                n_cf_folds))
        # do cross fitting
        mu_hat, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1))
        splitter = StratifiedKFold(n_splits=n_cf_folds, shuffle=True,
                                   random_state=seed)

        fold_count = 1
        for train_idx, test_idx in splitter.split(X, w):
            if verbose > 0:
                print('Training fold {}.'.format(fold_count))
            fold_count = fold_count + 1

            pred_mask = onp.zeros(n, dtype=bool)
            pred_mask[test_idx] = 1
            fit_mask = ~pred_mask

            mu_hat[pred_mask], pi_hat[pred_mask] = \
                _train_and_predict_r_stage1(
                    X, y, w, fit_mask, pred_mask,
                    n_layers_out=n_layers_out, n_layers_r=n_layers_r,
                    n_units_out=n_units_out, n_units_r=n_units_r,
                    penalty_l2=penalty_l2, step_size=step_size,
                    n_iter=n_iter, batch_size=batch_size,
                    val_split_prop=val_split_prop,
                    early_stopping=early_stopping, patience=patience,
                    n_iter_min=n_iter_min, verbose=verbose,
                    n_iter_print=n_iter_print, seed=seed, nonlin=nonlin)

    if verbose > 0:
        print('Training second stage.')

    if p is not None:
        # use known propensity score
        p = check_shape_1d_data(p)
        pi_hat = p

    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    w_ortho = w - pi_hat
    y_ortho = y - mu_hat

    if second_stage_strategy == R_STRATEGY_NAME:
        return train_r_stage2(X, y_ortho, w_ortho,
                              n_layers_out=n_layers_out_t,
                              n_units_out=n_units_out_t,
                              n_layers_r=n_layers_r_t,
                              n_units_r=n_units_r_t,
                              penalty_l2=penalty_l2_t,
                              step_size=step_size_t,
                              n_iter=n_iter, batch_size=batch_size,
                              val_split_prop=val_split_prop,
                              early_stopping=early_stopping,
                              patience=patience, n_iter_min=n_iter_min,
                              verbose=verbose, n_iter_print=n_iter_print,
                              seed=seed, return_val_loss=return_val_loss,
                              nonlin=nonlin)
    elif second_stage_strategy == U_STRATEGY_NAME:
        return train_output_net_only(X, y_ortho / w_ortho,
                                     n_layers_out=n_layers_out_t,
                                     n_units_out=n_units_out_t,
                                     n_layers_r=n_layers_r_t,
                                     n_units_r=n_units_r_t,
                                     penalty_l2=penalty_l2_t,
                                     step_size=step_size_t,
                                     n_iter=n_iter, batch_size=batch_size,
                                     val_split_prop=val_split_prop,
                                     early_stopping=early_stopping,
                                     patience=patience,
                                     n_iter_min=n_iter_min,
                                     verbose=verbose,
                                     n_iter_print=n_iter_print, seed=seed,
                                     return_val_loss=return_val_loss,
                                     nonlin=nonlin)
    else:
        raise ValueError('R-learner only supports strategies R and U.')
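
# Illustrative usage sketch (not part of the original source): the full R-learner
# pipeline via train_r_net on synthetic data. The second stage regresses the
# residualised outcome on the residualised treatment times tau(x), so
# predict_fun(params, X) returns CATE estimates directly. This assumes the
# first-stage helper _train_and_predict_r_stage1 accepts 1-D y and w as the rest
# of this module does; data and hyperparameters are made up for the example.
def _example_r_net():
    import numpy as onp

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)
    y = X[:, 1] + w * X[:, 0] + 0.1 * rng.normal(size=n)

    # default settings: no data splitting or cross-fitting for the nuisance stage
    params, predict_fun = train_r_net(X, y, w, n_iter=100, verbose=0)
    return predict_fun(params, X)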
def train_snet2(X, y, w, binary_y: bool = False,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                reg_diff: bool = False,
                penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation
    term].
    """
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layer
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r,
                                                n_units=n_units_r,
                                                nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(
        n_layers_out=n_layers_out, n_units_out=n_units_out,
        binary_y=binary_y, nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(
        n_layers_out=n_layers_out, n_units_out=n_units_out,
        binary_y=True, nonlin=nonlin)

    def init_fun_snet2(rng, input_shape):
        # chain together the layers
        # param should look like [repr, po_0, po_1, prop]
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
        return input_shape, [param_repr, param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:
        def loss_head(params, batch):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)
        return -jnp.sum(targets * jnp.log(preds) +
                        (1 - targets) * jnp.log(1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet2(params, batch, penalty_l2, penalty_diff):
        # params: list[representation, head_0, head_1, head_prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representation
        reps = predict_fun_repr(params[0], X)

        # pass down to heads
        loss_0 = loss_head(params[1], (reps, y, 1 - w))
        loss_1 = loss_head(params[2], (reps, y, w))

        # pass down to propensity head
        loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)
        weightsq_prop = sum([jnp.sum(params[3][i][0] ** 2)
                             for i in range(0, 2 * n_layers_out + 1, 2)])

        weightsq_body = sum([jnp.sum(params[0][i][0] ** 2)
                             for i in range(0, 2 * n_layers_r, 2)])
        weightsq_head = heads_l2_penalty(params[1], params[2], n_layers_out,
                                         reg_diff, penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + 0.5 * (
                penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + \
                0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) +
                       weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_diff):
        # updating function
        params = get_params(state)
        return opt_update(
            i, grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state)

    # initialise states
    _, init_params = init_fun_snet2(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)

        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size,
                                                          n - 1)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_diff)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet2(params_curr, (X_val, y_val, w_val), penalty_l2,
                                penalty_diff)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet2(params_curr, (X_val, y_val, w_val), 0, 0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty
        l_final = loss_snet2(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
        return trained_params, (predict_fun_repr, predict_fun_head_po,
                                predict_fun_head_prop), l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po,
                            predict_fun_head_prop)
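
# Illustrative usage sketch (not part of the original source): composing CATE and
# propensity estimates from the objects returned by train_snet2. Following
# loss_snet2 above, params = [representation, head_0, head_1, head_prop], and the
# potential-outcome heads are evaluated on the shared representation. Data and
# hyperparameters are assumptions made for this example only.
def _example_snet2():
    import numpy as onp

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)
    y = X[:, 1] + w * X[:, 0] + 0.1 * rng.normal(size=n)

    params, (predict_fun_repr, predict_fun_head_po,
             predict_fun_head_prop) = train_snet2(X, y, w, n_iter=100, verbose=0)

    reps = predict_fun_repr(params[0], X)            # shared representation
    mu_0_hat = predict_fun_head_po(params[1], reps)  # control PO head
    mu_1_hat = predict_fun_head_po(params[2], reps)  # treated PO head
    pi_hat = predict_fun_head_prop(params[3], reps)  # propensity head
    return mu_1_hat - mu_0_hat, pi_hat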
def train_snet3(X, y, w, binary_y: bool = False,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
                n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_units_out: int = DEFAULT_UNITS_OUT,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                penalty_disc: float = DEFAULT_PENALTY_DISC,
                penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                step_size: float = DEFAULT_STEP_SIZE,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                return_val_loss: bool = False,
                reg_diff: bool = False,
                penalty_diff: float = DEFAULT_PENALTY_L2,
                nonlin: str = DEFAULT_NONLIN,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
                same_init: bool = False):
    """
    SNet-3, based on the decomposition used in Hassanpour and Greiner (2020)
    """
    # function to train a net with 3 representations
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)
    d = X.shape[1]
    input_shape = (-1, d)
    rng_key = random.PRNGKey(seed)
    onp.random.seed(seed)  # set seed for data generation via numpy as well

    if not reg_diff:
        penalty_diff = penalty_l2

    # get validation split (can be none)
    X, y, w, X_val, y_val, w_val, val_string = make_val_split(
        X, y, w, val_split_prop=val_split_prop, seed=seed)
    n = X.shape[0]  # could be different from before due to split

    # get representation layers
    init_fun_repr, predict_fun_repr = ReprBlock(n_layers=n_layers_r,
                                                n_units=n_units_r,
                                                nonlin=nonlin)
    init_fun_repr_small, predict_fun_repr_small = ReprBlock(
        n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin)

    # get output head functions (output heads share same structure)
    init_fun_head_po, predict_fun_head_po = OutputHead(
        n_layers_out=n_layers_out, n_units_out=n_units_out,
        binary_y=binary_y, nonlin=nonlin)
    # add propensity head
    init_fun_head_prop, predict_fun_head_prop = OutputHead(
        n_layers_out=n_layers_out, n_units_out=n_units_out,
        binary_y=True, nonlin=nonlin)

    def init_fun_snet3(rng, input_shape):
        # chain together the layers
        # param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop]
        # initialise representation layers
        rng, layer_rng = random.split(rng)
        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)
        rng, layer_rng = random.split(rng)
        input_shape_repr_small, param_repr_o = init_fun_repr_small(layer_rng,
                                                                   input_shape)
        rng, layer_rng = random.split(rng)
        _, param_repr_w = init_fun_repr_small(layer_rng, input_shape)

        # each head gets two representations
        input_shape_repr = input_shape_repr[:-1] + \
            (input_shape_repr[-1] + input_shape_repr_small[-1],)

        # initialise output heads
        rng, layer_rng = random.split(rng)
        if same_init:
            # initialise both on same values
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        else:
            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
            rng, layer_rng = random.split(rng)
            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
        rng, layer_rng = random.split(rng)
        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
        return input_shape, [param_repr_c, param_repr_o, param_repr_w,
                             param_0, param_1, param_prop]

    # Define loss functions
    # loss functions for the head
    if not binary_y:
        def loss_head(params, batch, penalty):
            # mse loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return jnp.sum(weights * ((preds - targets) ** 2))
    else:
        def loss_head(params, batch, penalty):
            # log loss function
            inputs, targets, weights = batch
            preds = predict_fun_head_po(params, inputs)
            return -jnp.sum(weights * (targets * jnp.log(preds) +
                                       (1 - targets) * jnp.log(1 - preds)))

    def loss_head_prop(params, batch, penalty):
        # log loss function for propensities
        inputs, targets = batch
        preds = predict_fun_head_prop(params, inputs)
        return -jnp.sum(targets * jnp.log(preds) +
                        (1 - targets) * jnp.log(1 - preds))

    # complete loss function for all parts
    @jit
    def loss_snet3(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # params: list[repr_c, repr_o, repr_t, po_0, po_1, prop]
        # batch: (X, y, w)
        X, y, w = batch

        # get representations
        reps_c = predict_fun_repr(params[0], X)
        reps_o = predict_fun_repr_small(params[1], X)
        reps_w = predict_fun_repr_small(params[2], X)

        # concatenate
        reps_po = _concatenate_representations((reps_c, reps_o))
        reps_prop = _concatenate_representations((reps_c, reps_w))

        # pass down to heads
        loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)

        # pass down to propensity head
        loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)
        weightsq_prop = sum([jnp.sum(params[5][i][0] ** 2)
                             for i in range(0, 2 * n_layers_out + 1, 2)])

        # which variable has impact on which representation
        col_c = _get_absolute_rowsums(params[0][0][0])
        col_o = _get_absolute_rowsums(params[1][0][0])
        col_w = _get_absolute_rowsums(params[2][0][0])
        loss_o = penalty_orthogonal * (
            jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o))

        # is rep_o balanced between groups?
        loss_disc = penalty_disc * mmd2_lin(reps_o, w)

        # weight decay on representations
        weightsq_body = sum([sum([jnp.sum(params[j][i][0] ** 2)
                                  for i in range(0, 2 * n_layers_r, 2)])
                             for j in range(3)])
        weightsq_head = heads_l2_penalty(params[3], params[4], n_layers_out,
                                         reg_diff, penalty_l2, penalty_diff)

        if not avg_objective:
            return loss_0 + loss_1 + loss_prop + loss_o + loss_disc + \
                0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) +
                       weightsq_head)
        else:
            n_batch = y.shape[0]
            return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + \
                loss_o + loss_disc + \
                0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) +
                       weightsq_head)

    # Define optimisation routine
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)

    @jit
    def update(i, state, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        # updating function
        params = get_params(state)
        return opt_update(i, grad(loss_snet3)(params, batch, penalty_l2,
                                              penalty_orthogonal, penalty_disc),
                          state)

    # initialise states
    _, init_params = init_fun_snet3(rng_key, input_shape)
    opt_state = opt_init(init_params)

    # calculate number of batches per epoch
    batch_size = batch_size if batch_size < n else n
    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
    train_indices = onp.arange(n)

    l_best = LARGE_VAL
    p_curr = 0

    # do training
    for i in range(n_iter):
        # shuffle data for minibatches
        onp.random.shuffle(train_indices)

        for b in range(n_batches):
            idx_next = train_indices[(b * batch_size):min((b + 1) * batch_size,
                                                          n - 1)]
            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
            opt_state = update(i * n_batches + b, opt_state, next_batch,
                               penalty_l2, penalty_orthogonal, penalty_disc)

        if (verbose > 0 and i % n_iter_print == 0) or early_stopping:
            params_curr = get_params(opt_state)
            l_curr = loss_snet3(params_curr, (X_val, y_val, w_val), penalty_l2,
                                penalty_orthogonal, penalty_disc)

        if verbose > 0 and i % n_iter_print == 0:
            print("Epoch: {}, current {} loss {}".format(i, val_string, l_curr))

        if early_stopping and ((i + 1) * n_batches > n_iter_min):
            # check if loss updated
            if l_curr < l_best:
                l_best = l_curr
                p_curr = 0
                params_best = params_curr
            else:
                if onp.isnan(l_curr):
                    # if diverged, return best
                    return params_best, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop)
                p_curr = p_curr + 1

            if p_curr > patience:
                if return_val_loss:
                    # return loss without penalty
                    l_final = loss_snet3(params_curr, (X_val, y_val, w_val),
                                         0, 0, 0)
                    return params_curr, (predict_fun_repr, predict_fun_head_po,
                                         predict_fun_head_prop), l_final

                return params_curr, (predict_fun_repr, predict_fun_head_po,
                                     predict_fun_head_prop)

    # return the parameters
    trained_params = get_params(opt_state)

    if return_val_loss:
        # return loss without penalty (all three penalties set to zero)
        l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val),
                             0, 0, 0)
        return trained_params, (predict_fun_repr, predict_fun_head_po,
                                predict_fun_head_prop), l_final

    return trained_params, (predict_fun_repr, predict_fun_head_po,
                            predict_fun_head_prop)
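
# Illustrative usage sketch (not part of the original source): composing CATE
# estimates from train_snet3's output. Mirroring loss_snet3 above, params =
# [repr_c, repr_o, repr_w, po_0, po_1, prop]; the PO heads act on the concatenation
# of the shared and outcome-specific representations. Only predict_fun_repr is
# returned, so it is reused here for the smaller representation blocks on the
# assumption that the ReprBlock predict function depends only on the layer
# structure, not the width. Data and hyperparameters are made up for the example.
def _example_snet3():
    import numpy as onp

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)
    y = X[:, 1] + w * X[:, 0] + 0.1 * rng.normal(size=n)

    params, (predict_fun_repr, predict_fun_head_po,
             predict_fun_head_prop) = train_snet3(X, y, w, n_iter=100, verbose=0)

    reps_c = predict_fun_repr(params[0], X)   # confounder representation
    reps_o = predict_fun_repr(params[1], X)   # outcome-only representation
    reps_po = _concatenate_representations((reps_c, reps_o))
    mu_0_hat = predict_fun_head_po(params[3], reps_po)
    mu_1_hat = predict_fun_head_po(params[4], reps_po)
    return mu_1_hat - mu_0_hat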
def train_twostep_net(X, y, w, p=None,
                      first_stage_strategy: str = T_STRATEGY,
                      data_split: bool = False,
                      cross_fit: bool = False,
                      n_cf_folds: int = DEFAULT_CF_FOLDS,
                      transformation: str = AIPW_TRANSFORMATION,
                      binary_y: bool = False,
                      n_layers_out: int = DEFAULT_LAYERS_OUT,
                      n_layers_r: int = DEFAULT_LAYERS_R,
                      n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                      n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                      n_units_out: int = DEFAULT_UNITS_OUT,
                      n_units_r: int = DEFAULT_UNITS_R,
                      n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                      n_units_r_t: int = DEFAULT_UNITS_R_T,
                      penalty_l2: float = DEFAULT_PENALTY_L2,
                      penalty_l2_t: float = DEFAULT_PENALTY_L2,
                      step_size: float = DEFAULT_STEP_SIZE,
                      step_size_t: float = DEFAULT_STEP_SIZE_T,
                      n_iter: int = DEFAULT_N_ITER,
                      batch_size: int = DEFAULT_BATCH_SIZE,
                      val_split_prop: float = DEFAULT_VAL_SPLIT,
                      early_stopping: bool = True,
                      patience: int = DEFAULT_PATIENCE,
                      n_iter_min: int = DEFAULT_N_ITER_MIN,
                      verbose: int = 1,
                      n_iter_print: int = DEFAULT_N_ITER_PRINT,
                      seed: int = DEFAULT_SEED,
                      rescale_transformation: bool = False,
                      return_val_loss: bool = False,
                      penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
                      n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S,
                      nonlin: str = DEFAULT_NONLIN,
                      avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    # get shape of data
    n, d = X.shape

    if p is not None:
        p = check_shape_1d_data(p)

    # get transformation function
    transformation_function = _get_transformation_function(transformation)

    # get strategy name
    if first_stage_strategy not in ALL_STRATEGIES:
        raise ValueError('Parameter first stage should be in '
                         'catenets.models.twostep_nets.ALL_STRATEGIES. '
                         'You passed {}'.format(first_stage_strategy))

    # split data as wanted
    if p is None or transformation is not HT_TRANSFORMATION:
        if not cross_fit:
            if not data_split:
                if verbose > 0:
                    print('Training first stage with all data (no data splitting)')
                # use all data for both
                fit_mask = onp.ones(n, dtype=bool)
                pred_mask = onp.ones(n, dtype=bool)
            else:
                if verbose > 0:
                    print('Training first stage with half of the data '
                          '(data splitting)')
                # split data in half
                fit_idx = onp.random.choice(n, int(onp.round(n / 2)))
                fit_mask = onp.zeros(n, dtype=bool)
                fit_mask[fit_idx] = 1
                pred_mask = ~fit_mask

            mu_0, mu_1, pi_hat = _train_and_predict_first_stage(
                X, y, w, fit_mask, pred_mask,
                first_stage_strategy=first_stage_strategy, binary_y=binary_y,
                n_layers_out=n_layers_out, n_layers_r=n_layers_r,
                n_units_out=n_units_out, n_units_r=n_units_r,
                penalty_l2=penalty_l2, step_size=step_size,
                n_iter=n_iter, batch_size=batch_size,
                val_split_prop=val_split_prop, early_stopping=early_stopping,
                patience=patience, n_iter_min=n_iter_min, verbose=verbose,
                n_iter_print=n_iter_print, seed=seed,
                penalty_orthogonal=penalty_orthogonal,
                n_units_r_small=n_units_r_small, nonlin=nonlin,
                avg_objective=avg_objective, transformation=transformation)

            if data_split:
                # keep only prediction data
                X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :]
                if p is not None:
                    p = p[pred_mask, :]
        else:
            if verbose > 0:
                print('Training first stage in {} folds (cross-fitting)'.format(
                    n_cf_folds))
            # do cross fitting
            mu_0, mu_1, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)), \
                onp.zeros((n, 1))
            splitter = StratifiedKFold(n_splits=n_cf_folds, shuffle=True,
                                       random_state=seed)

            fold_count = 1
            for train_idx, test_idx in splitter.split(X, w):
                if verbose > 0:
                    print('Training fold {}.'.format(fold_count))
                fold_count = fold_count + 1

                pred_mask = onp.zeros(n, dtype=bool)
                pred_mask[test_idx] = 1
                fit_mask = ~pred_mask

                mu_0[pred_mask], mu_1[pred_mask], pi_hat[pred_mask] = \
                    _train_and_predict_first_stage(
                        X, y, w, fit_mask, pred_mask,
                        first_stage_strategy=first_stage_strategy,
                        binary_y=binary_y,
                        n_layers_out=n_layers_out, n_layers_r=n_layers_r,
                        n_units_out=n_units_out, n_units_r=n_units_r,
                        penalty_l2=penalty_l2, step_size=step_size,
                        n_iter=n_iter, batch_size=batch_size,
                        val_split_prop=val_split_prop,
                        early_stopping=early_stopping,
                        patience=patience, n_iter_min=n_iter_min,
                        verbose=verbose, n_iter_print=n_iter_print, seed=seed,
                        penalty_orthogonal=penalty_orthogonal,
                        n_units_r_small=n_units_r_small, nonlin=nonlin,
                        avg_objective=avg_objective,
                        transformation=transformation)

    if verbose > 0:
        print('Training second stage.')

    if p is not None:
        # use known propensity score
        p = check_shape_1d_data(p)
        pi_hat = p

    # second stage
    y, w = check_shape_1d_data(y), check_shape_1d_data(w)

    # transform data and fit on transformed data
    if transformation is HT_TRANSFORMATION:
        mu_0 = None
        mu_1 = None

    pseudo_outcome = transformation_function(y=y, w=w, p=pi_hat,
                                             mu_0=mu_0, mu_1=mu_1)

    if rescale_transformation:
        scale_factor = onp.std(y) / onp.std(pseudo_outcome)
        if scale_factor > 1:
            scale_factor = 1
        else:
            pseudo_outcome = scale_factor * pseudo_outcome
        params, predict_funs = train_output_net_only(
            X, pseudo_outcome, binary_y=False,
            n_layers_out=n_layers_out_t, n_units_out=n_units_out_t,
            n_layers_r=n_layers_r_t, n_units_r=n_units_r_t,
            penalty_l2=penalty_l2_t, step_size=step_size_t,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            return_val_loss=return_val_loss, nonlin=nonlin,
            avg_objective=avg_objective)
        return params, predict_funs, scale_factor
    else:
        return train_output_net_only(
            X, pseudo_outcome, binary_y=False,
            n_layers_out=n_layers_out_t, n_units_out=n_units_out_t,
            n_layers_r=n_layers_r_t, n_units_r=n_units_r_t,
            penalty_l2=penalty_l2_t, step_size=step_size_t,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            return_val_loss=return_val_loss, nonlin=nonlin,
            avg_objective=avg_objective)
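
# Illustrative usage sketch (not part of the original source): the two-step
# pseudo-outcome learner with the default AIPW/DR transformation. The first stage
# fits mu_0, mu_1 and pi; the second stage regresses the pseudo-outcome on X, so
# the returned predictor gives CATE estimates directly. train_output_net_only is
# assumed to return a (params, predict_fun) pair as used elsewhere in this module;
# data and hyperparameters are made up for the example.
def _example_twostep_net():
    import numpy as onp

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)
    y = X[:, 1] + w * X[:, 0] + 0.1 * rng.normal(size=n)

    params, predict_fun = train_twostep_net(X, y, w,
                                            transformation=AIPW_TRANSFORMATION,
                                            n_iter=100, verbose=0)
    return predict_fun(params, X)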
def train_x_net(X, y, w,
                weight_strategy: int = None,
                binary_y: bool = False,
                n_layers_out: int = DEFAULT_LAYERS_OUT,
                n_layers_r: int = DEFAULT_LAYERS_R,
                n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
                n_layers_r_t: int = DEFAULT_LAYERS_R_T,
                n_units_out: int = DEFAULT_UNITS_OUT,
                n_units_r: int = DEFAULT_UNITS_R,
                n_units_out_t: int = DEFAULT_UNITS_OUT_T,
                n_units_r_t: int = DEFAULT_UNITS_R_T,
                penalty_l2: float = DEFAULT_PENALTY_L2,
                penalty_l2_t: float = DEFAULT_PENALTY_L2,
                step_size: float = DEFAULT_STEP_SIZE,
                step_size_t: float = DEFAULT_STEP_SIZE_T,
                n_iter: int = DEFAULT_N_ITER,
                batch_size: int = DEFAULT_BATCH_SIZE,
                n_iter_min: int = DEFAULT_N_ITER_MIN,
                val_split_prop: float = DEFAULT_VAL_SPLIT,
                early_stopping: bool = True,
                patience: int = DEFAULT_PATIENCE,
                verbose: int = 1,
                n_iter_print: int = DEFAULT_N_ITER_PRINT,
                seed: int = DEFAULT_SEED,
                nonlin: str = DEFAULT_NONLIN,
                return_val_loss: bool = False,
                avg_objective: bool = DEFAULT_AVG_OBJECTIVE):
    y = check_shape_1d_data(y)

    if len(w.shape) > 1:
        w = w.reshape((len(w),))

    if weight_strategy not in [0, 1, -1, None]:
        # weight_strategy is coded as follows:
        # for tau(x)=g(x)tau_0(x) + (1-g(x))tau_1(x) [eq 9, kuenzel et al (2019)]
        # weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1,
        # weight_strategy=None sets g(x)=pi(x) [propensity score],
        # weight_strategy=-1 sets g(x)=(1-pi(x))
        raise ValueError('XNet only implements weight_strategy in [0, 1, -1, None]')

    # first stage: get estimates of PO regression
    if verbose > 0:
        print("Training first stage")

    if not weight_strategy == 1:
        if verbose > 0:
            print('Training PO_0 Net')
        params_0, predict_fun_0 = train_output_net_only(
            X[w == 0], y[w == 0], binary_y=binary_y,
            n_layers_out=n_layers_out, n_units_out=n_units_out,
            n_layers_r=n_layers_r, n_units_r=n_units_r,
            penalty_l2=penalty_l2, step_size=step_size,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            nonlin=nonlin, avg_objective=avg_objective)
        mu_hat_0 = predict_fun_0(params_0, X[w == 1])
    else:
        mu_hat_0 = None

    if not weight_strategy == 0:
        if verbose > 0:
            print('Training PO_1 Net')
        params_1, predict_fun_1 = train_output_net_only(
            X[w == 1], y[w == 1], binary_y=binary_y,
            n_layers_out=n_layers_out, n_units_out=n_units_out,
            n_layers_r=n_layers_r, n_units_r=n_units_r,
            penalty_l2=penalty_l2, step_size=step_size,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            nonlin=nonlin, avg_objective=avg_objective)
        mu_hat_1 = predict_fun_1(params_1, X[w == 0])
    else:
        mu_hat_1 = None

    if weight_strategy is None or weight_strategy == -1:
        # also fit propensity estimator
        if verbose > 0:
            print('Training propensity net')
        params_prop, predict_fun_prop = train_output_net_only(
            X, w, binary_y=True,
            n_layers_out=n_layers_out, n_units_out=n_units_out,
            n_layers_r=n_layers_r, n_units_r=n_units_r,
            penalty_l2=penalty_l2, step_size=step_size,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            nonlin=nonlin, avg_objective=avg_objective)
    else:
        params_prop, predict_fun_prop = None, None

    # second stage
    if verbose > 0:
        print("Training second stage")

    if not weight_strategy == 0:
        # fit tau_0
        if verbose > 0:
            print("Fitting tau_0")
        pseudo_outcome0 = mu_hat_1 - y[w == 0]
        params_tau0, predict_fun_tau0 = train_output_net_only(
            X[w == 0], pseudo_outcome0, binary_y=False,
            n_layers_out=n_layers_out_t, n_units_out=n_units_out_t,
            n_layers_r=n_layers_r_t, n_units_r=n_units_r_t,
            penalty_l2=penalty_l2_t, step_size=step_size_t,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            return_val_loss=return_val_loss, nonlin=nonlin,
            avg_objective=avg_objective)
    else:
        params_tau0, predict_fun_tau0 = None, None

    if not weight_strategy == 1:
        # fit tau_1
        if verbose > 0:
            print("Fitting tau_1")
        pseudo_outcome1 = y[w == 1] - mu_hat_0
        params_tau1, predict_fun_tau1 = train_output_net_only(
            X[w == 1], pseudo_outcome1, binary_y=False,
            n_layers_out=n_layers_out_t, n_units_out=n_units_out_t,
            n_layers_r=n_layers_r_t, n_units_r=n_units_r_t,
            penalty_l2=penalty_l2_t, step_size=step_size_t,
            n_iter=n_iter, batch_size=batch_size,
            val_split_prop=val_split_prop, early_stopping=early_stopping,
            patience=patience, n_iter_min=n_iter_min,
            n_iter_print=n_iter_print, verbose=verbose, seed=seed,
            return_val_loss=return_val_loss, nonlin=nonlin,
            avg_objective=avg_objective)
    else:
        params_tau1, predict_fun_tau1 = None, None

    params = params_tau0, params_tau1, params_prop
    predict_funs = predict_fun_tau0, predict_fun_tau1, predict_fun_prop

    return params, predict_funs
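
# Illustrative usage sketch (not part of the original source): combining the three
# components returned by train_x_net into a single CATE estimate. With
# weight_strategy=None, g(x) is the estimated propensity score, so
# tau(x) = pi(x) * tau_0(x) + (1 - pi(x)) * tau_1(x) as in the comment above
# (eq 9, Kuenzel et al, 2019). Data and hyperparameters are assumptions made for
# this example only.
def _example_x_net():
    import numpy as onp

    rng = onp.random.RandomState(0)
    n, d = 500, 10
    X = rng.normal(size=(n, d))
    w = rng.binomial(1, 0.5, size=n)
    y = X[:, 1] + w * X[:, 0] + 0.1 * rng.normal(size=n)

    (params_tau0, params_tau1, params_prop), \
        (predict_fun_tau0, predict_fun_tau1, predict_fun_prop) = \
        train_x_net(X, y, w, weight_strategy=None, n_iter=100, verbose=0)

    tau0_hat = predict_fun_tau0(params_tau0, X)
    tau1_hat = predict_fun_tau1(params_tau1, X)
    pi_hat = predict_fun_prop(params_prop, X)
    return pi_hat * tau0_hat + (1 - pi_hat) * tau1_hat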