def loss_snet2(params, batch, penalty_l2, penalty_diff):
    """Combined training objective for SNet2.

    Args:
        params: list [repr_body, head_0, head_1, head_prop] of layer params.
        batch: tuple (X, y, w) of covariates, outcomes and treatment indicators.
        penalty_l2: l2 weight-decay coefficient for body and propensity weights.
        penalty_diff: penalty coefficient forwarded to ``heads_l2_penalty``.

    Returns:
        Scalar loss: both outcome-head losses + propensity-head loss
        + 0.5 * weight regularization.
    """
    X, y, w = batch

    # shared representation feeding all heads
    reps = predict_fun_repr(params[0], X)

    # outcome heads: head_0 is weighted by (1 - w), head_1 by w
    loss_0 = loss_head(params[1], (reps, y, 1 - w))
    loss_1 = loss_head(params[2], (reps, y, w))

    # propensity head predicts w from the representation
    loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)

    # squared weight norms; even indices index the dense-layer weight tuples
    weightsq_prop = sum(
        jnp.sum(params[3][i][0] ** 2)
        for i in range(0, 2 * n_layers_out + 1, 2)
    )
    weightsq_body = sum(
        jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)
    )
    weightsq_head = heads_l2_penalty(
        params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff
    )

    reg = 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    if avg_objective:
        n_batch = y.shape[0]
        return (loss_0 + loss_1) / n_batch + loss_prop / n_batch + reg
    return loss_0 + loss_1 + loss_prop + reg
def loss_snet1(params, batch, penalty_l2, penalty_disc, penalty_diff):
    """Combined training objective for SNet1 (shared body, two outcome heads,
    plus an MMD balancing term on the representation).

    Args:
        params: list [repr_body, head_0, head_1] of layer params.
        batch: tuple (X, y, w) of covariates, outcomes and treatment indicators.
        penalty_l2: l2 weight-decay coefficient for the representation body.
        penalty_disc: coefficient on the linear-MMD discrepancy term.
        penalty_diff: penalty coefficient forwarded to ``heads_l2_penalty``.

    Returns:
        Scalar loss: outcome-head losses + discrepancy + 0.5 * regularization.
    """
    X, y, w = batch

    # shared representation
    reps = predict_fun_repr(params[0], X)

    # linear MMD between treated and control representations
    disc = mmd2_lin(reps, w)

    # outcome heads: head_0 weighted by (1 - w), head_1 by w
    loss_0 = loss_head(params[1], (reps, y, 1 - w))
    loss_1 = loss_head(params[2], (reps, y, w))

    # l2 regularization on representation body (even indices = dense layers)
    weightsq_body = sum(
        jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)
    )
    weightsq_head = heads_l2_penalty(
        params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff
    )

    reg = 0.5 * (penalty_l2 * weightsq_body + weightsq_head)

    head_loss = loss_0 + loss_1
    if avg_objective:
        head_loss = head_loss / y.shape[0]
    return head_loss + penalty_disc * disc + reg
def loss_snet(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
    """Combined training objective for SNet with five disentangled
    representations.

    Expected ``params`` layout:
        [param_repr_c, param_repr_o, param_repr_mu0, param_repr_mu1,
         param_repr_w, param_head_0, param_head_1, param_head_prop]

    Args:
        params: list of layer params in the order above.
        batch: tuple (X, y, w) of covariates, outcomes and treatment indicators.
        penalty_l2: l2 weight-decay coefficient.
        penalty_orthogonal: coefficient on the pairwise orthogonality penalty.
        penalty_disc: coefficient on the MMD balancing term on reps_o.

    NOTE(review): ``penalty_diff`` is read from the enclosing scope here,
    unlike loss_snet1/loss_snet2 where it is a parameter.
    """
    X, y, w = batch

    # five representations (the _small variants use a narrower network)
    reps_c = predict_fun_repr(params[0], X)
    reps_o = predict_fun_repr_small(params[1], X)
    reps_mu0 = predict_fun_repr_small(params[2], X)
    reps_mu1 = predict_fun_repr_small(params[3], X)
    reps_w = predict_fun_repr(params[4], X)

    # head inputs are concatenations of the relevant representations
    reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0))
    reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1))
    reps_prop = _concatenate_representations((reps_c, reps_w))

    # outcome heads: head_0 weighted by (1 - w), head_1 by w
    loss_0 = loss_head(params[5], (reps_po_0, y, 1 - w), penalty_l2)
    loss_1 = loss_head(params[6], (reps_po_1, y, w), penalty_l2)

    # propensity head predicts w
    loss_prop = loss_head_prop(params[7], (reps_prop, w), penalty_l2)

    # is reps_o balanced between treatment groups?
    loss_disc = penalty_disc * mmd2_lin(reps_o, w)

    # orthogonality: absolute row-sums of each first-layer weight matrix
    # indicate which inputs feed which representation; penalize all pairwise
    # overlaps between the five representations
    col_c, col_o, col_mu0, col_mu1, col_w = (
        _get_absolute_rowsums(params[j][0][0]) for j in range(5)
    )
    loss_o = penalty_orthogonal * jnp.sum(
        col_c * col_o + col_c * col_w + col_c * col_mu1 + col_c * col_mu0
        + col_w * col_o + col_mu0 * col_o + col_o * col_mu1
        + col_mu0 * col_mu1 + col_mu0 * col_w + col_w * col_mu1
    )

    # weight decay on all five representation bodies
    weightsq_body = sum(
        jnp.sum(params[j][i][0] ** 2)
        for j in range(5)
        for i in range(0, 2 * n_layers_r, 2)
    )
    weightsq_head = heads_l2_penalty(
        params[5], params[6], n_layers_out, reg_diff, penalty_l2, penalty_diff
    )
    weightsq_prop = sum(
        jnp.sum(params[7][i][0] ** 2)
        for i in range(0, 2 * n_layers_out + 1, 2)
    )

    reg = 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)

    if avg_objective:
        n_batch = y.shape[0]
        return (
            (loss_0 + loss_1) / n_batch
            + loss_prop / n_batch
            + loss_disc
            + loss_o
            + reg
        )
    return loss_0 + loss_1 + loss_prop + loss_disc + loss_o + reg
def loss_tnet(params, batch, penalty_l2, penalty_diff):
    """Training objective for TNet: two fully separate outcome networks
    consuming the raw covariates directly (no shared representation).

    Args:
        params: list [head_0, head_1] of layer params, one full net per arm.
        batch: tuple (X, y, w) of covariates, outcomes and treatment indicators.
        penalty_l2: l2 weight-decay coefficient.
        penalty_diff: penalty coefficient forwarded to ``heads_l2_penalty``.

    Returns:
        Scalar loss: both head losses + 0.5 * head regularization.
    """
    X, y, w = batch

    # each arm's network sees X directly; weighting by (1 - w) / w selects arm
    loss_0 = loss_head(params[0], (X, y, 1 - w))
    loss_1 = loss_head(params[1], (X, y, w))

    # regularize the full networks; note the fourth argument is hard-coded to
    # True here, whereas the other losses in this file pass reg_diff
    weightsq_head = heads_l2_penalty(
        params[0], params[1], n_layers_r + n_layers_out, True,
        penalty_l2, penalty_diff
    )

    total = loss_0 + loss_1
    if avg_objective:
        total = total / y.shape[0]
    return total + 0.5 * weightsq_head
def loss_snet_noprop(params, batch, penalty_l2, penalty_orthogonal):
    """Training objective for the SNet variant without a propensity head.

    Expected ``params`` layout:
        [repr_o, repr_p0, repr_p1, head_po_0, head_po_1]

    Args:
        params: list of layer params in the order above.
        batch: tuple (X, y, w) of covariates, outcomes and treatment indicators.
        penalty_l2: l2 weight-decay coefficient.
        penalty_orthogonal: coefficient on the pairwise orthogonality penalty.

    NOTE(review): ``penalty_diff`` is read from the enclosing scope here,
    unlike loss_snet1/loss_snet2 where it is a parameter.
    """
    X, y, w = batch

    # shared representation plus one small representation per arm
    reps_o = predict_fun_repr(params[0], X)
    reps_p0 = predict_fun_repr_small(params[1], X)
    reps_p1 = predict_fun_repr_small(params[2], X)

    # head inputs: shared part concatenated with the arm-specific part
    reps_po0 = _concatenate_representations((reps_o, reps_p0))
    reps_po1 = _concatenate_representations((reps_o, reps_p1))

    # outcome heads: head_0 weighted by (1 - w), head_1 by w
    loss_0 = loss_head(params[3], (reps_po0, y, 1 - w), penalty_l2)
    loss_1 = loss_head(params[4], (reps_po1, y, w), penalty_l2)

    # orthogonality: absolute row-sums of first-layer weights indicate which
    # inputs feed which representation; penalize pairwise overlap
    col_o = _get_absolute_rowsums(params[0][0][0])
    col_p0 = _get_absolute_rowsums(params[1][0][0])
    col_p1 = _get_absolute_rowsums(params[2][0][0])
    loss_o = penalty_orthogonal * jnp.sum(
        col_o * col_p0 + col_o * col_p1 + col_p1 * col_p0
    )

    # weight decay on all three representation bodies
    weightsq_body = sum(
        jnp.sum(params[j][i][0] ** 2)
        for j in range(3)
        for i in range(0, 2 * n_layers_r, 2)
    )
    weightsq_head = heads_l2_penalty(
        params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff
    )

    reg = 0.5 * (penalty_l2 * weightsq_body + weightsq_head)

    if avg_objective:
        return (loss_0 + loss_1) / y.shape[0] + loss_o + reg
    return loss_0 + loss_1 + loss_o + reg