def loss_snet2(params, batch, penalty_l2, penalty_diff):
    """SNet2/DragonNet-style loss: shared representation, two outcome heads
    and a propensity head, plus L2 weight decay.

    params: list[representation, head_0, head_1, head_prop]
    batch: (X, y, w)

    NOTE(review): n_layers_out, n_layers_r, reg_diff and avg_objective are
    read from the enclosing scope, not from the argument list.
    """
    X, y, w = batch

    # shared representation feeding all three heads
    reps = predict_fun_repr(params[0], X)

    # outcome heads: head_0 fits controls (weight 1 - w), head_1 fits treated
    loss_0 = loss_head(params[1], (reps, y, 1 - w))
    loss_1 = loss_head(params[2], (reps, y, w))

    # propensity head predicts w from the shared representation
    loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)

    # squared weight norms; weight matrices sit at even indices of each
    # param list — presumably stax-style (W, b) pairs, confirm upstream
    weightsq_prop = sum(
        jnp.sum(params[3][idx][0] ** 2)
        for idx in range(0, 2 * n_layers_out + 1, 2))
    weightsq_body = sum(
        jnp.sum(params[0][idx][0] ** 2)
        for idx in range(0, 2 * n_layers_r, 2))
    weightsq_head = heads_l2_penalty(
        params[1], params[2], n_layers_out, reg_diff, penalty_l2,
        penalty_diff)

    if avg_objective:
        n_batch = y.shape[0]
        prediction_loss = (loss_0 + loss_1) / n_batch + loss_prop / n_batch
    else:
        prediction_loss = loss_0 + loss_1 + loss_prop
    return prediction_loss + 0.5 * (
        penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
    def loss_snet1(params, batch, penalty_l2, penalty_disc, penalty_diff):
        """SNet1/CFR-style loss: two outcome heads over a shared
        representation, with a linear-MMD balancing penalty.

        params: list[representation, head_0, head_1]
        batch: (X, y, w)

        NOTE(review): n_layers_out, n_layers_r, reg_diff and avg_objective
        come from the enclosing scope, not the argument list.
        """
        X, y, w = batch

        # shared representation
        reps = predict_fun_repr(params[0], X)

        # linear-MMD discrepancy between treated and control representations
        disc = mmd2_lin(reps, w)

        # head_0 is trained on controls (weight 1 - w), head_1 on treated
        loss_0 = loss_head(params[1], (reps, y, 1 - w))
        loss_1 = loss_head(params[2], (reps, y, w))

        # L2 penalty on representation weights (weight matrices at even
        # indices — presumably stax-style (W, b) pairs)
        weightsq_body = sum(
            jnp.sum(params[0][idx][0] ** 2)
            for idx in range(0, 2 * n_layers_r, 2))
        weightsq_head = heads_l2_penalty(
            params[1], params[2], n_layers_out, reg_diff, penalty_l2,
            penalty_diff)

        if avg_objective:
            n_batch = y.shape[0]
            prediction_loss = (loss_0 + loss_1) / n_batch
        else:
            prediction_loss = loss_0 + loss_1
        return prediction_loss + penalty_disc * disc + \
               0.5 * (penalty_l2 * weightsq_body + weightsq_head)
# ---- Example #3 ----
    def loss_snet(params, batch, penalty_l2, penalty_orthogonal, penalty_disc):
        """Full SNet loss: five disentangled representations, two outcome
        heads, a propensity head, an MMD balancing term and an orthogonal
        penalty that discourages inputs from feeding several representations.

        params layout: [repr_c, repr_o, repr_mu0, repr_mu1, repr_w,
                        head_0, head_1, head_prop]
        batch: (X, y, w)

        NOTE(review): reg_diff and penalty_diff are read from the enclosing
        scope, not from the argument list — confirm that is intentional.
        """
        X, y, w = batch

        # disentangled representations: confounders (c), outcome-only (o),
        # potential-outcome-specific (mu0 / mu1), propensity-only (w)
        reps_c = predict_fun_repr(params[0], X)
        reps_o = predict_fun_repr_small(params[1], X)
        reps_mu0 = predict_fun_repr_small(params[2], X)
        reps_mu1 = predict_fun_repr_small(params[3], X)
        reps_w = predict_fun_repr(params[4], X)

        # assemble per-head inputs
        reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0))
        reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1))
        reps_prop = _concatenate_representations((reps_c, reps_w))

        # outcome heads: head_0 on controls (weight 1 - w), head_1 on treated
        loss_0 = loss_head(params[5], (reps_po_0, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[6], (reps_po_1, y, w), penalty_l2)

        # propensity head
        loss_prop = loss_head_prop(params[7], (reps_prop, w), penalty_l2)

        # keep the outcome representation balanced between groups
        loss_disc = penalty_disc * mmd2_lin(reps_o, w)

        # orthogonal regularization: pairwise products of the absolute
        # first-layer row sums penalize overlap between representations
        col_c = _get_absolute_rowsums(params[0][0][0])
        col_o = _get_absolute_rowsums(params[1][0][0])
        col_mu0 = _get_absolute_rowsums(params[2][0][0])
        col_mu1 = _get_absolute_rowsums(params[3][0][0])
        col_w = _get_absolute_rowsums(params[4][0][0])
        loss_o = penalty_orthogonal * (jnp.sum(
            col_c * col_o + col_c * col_w + col_c * col_mu1 + col_c * col_mu0 +
            col_w * col_o + col_mu0 * col_o + col_o * col_mu1 +
            col_mu0 * col_mu1 + col_mu0 * col_w + col_w * col_mu1))

        # weight decay over all five representation networks (weight
        # matrices at even indices — presumably stax-style (W, b) pairs)
        weightsq_body = sum(
            sum(
                jnp.sum(params[rep][idx][0] ** 2)
                for idx in range(0, 2 * n_layers_r, 2))
            for rep in range(5))
        weightsq_head = heads_l2_penalty(
            params[5], params[6], n_layers_out, reg_diff, penalty_l2,
            penalty_diff)
        weightsq_prop = sum(
            jnp.sum(params[7][idx][0] ** 2)
            for idx in range(0, 2 * n_layers_out + 1, 2))

        regularization = 0.5 * (
            penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
        if avg_objective:
            n_batch = y.shape[0]
            # NOTE: disc and orthogonal terms are deliberately not averaged
            return ((loss_0 + loss_1) / n_batch + loss_prop / n_batch +
                    loss_disc + loss_o + regularization)
        return (loss_0 + loss_1 + loss_prop + loss_disc + loss_o +
                regularization)
# ---- Example #4 ----
    def loss_tnet(params, batch, penalty_l2, penalty_diff):
        """TNet loss: two fully separate outcome networks (no shared
        representation), each fit on its own treatment group.

        params: [net_0, net_1]
        batch: (X, y, w)

        NOTE(review): n_layers_r, n_layers_out and avg_objective come from
        the enclosing scope.
        """
        X, y, w = batch

        # net_0 sees controls (weight 1 - w), net_1 sees treated
        loss_0 = loss_head(params[0], (X, y, 1 - w))
        loss_1 = loss_head(params[1], (X, y, w))

        # L2 penalty over the full depth of both networks; True is passed
        # in the reg_diff position (cf. the other losses in this file)
        weightsq_head = heads_l2_penalty(
            params[0], params[1], n_layers_r + n_layers_out, True,
            penalty_l2, penalty_diff)

        if avg_objective:
            n_batch = y.shape[0]
            prediction_loss = (loss_0 + loss_1) / n_batch
        else:
            prediction_loss = loss_0 + loss_1
        return prediction_loss + 0.5 * (weightsq_head)
# ---- Example #5 ----
    def loss_snet_noprop(params, batch, penalty_l2, penalty_orthogonal):
        """SNet variant without a propensity head: one shared outcome
        representation plus two potential-outcome-specific representations,
        with an orthogonal penalty on the first-layer weights.

        params layout: [repr_o, repr_p0, repr_p1, head_0, head_1]
        batch: (X, y, w)

        NOTE(review): reg_diff and penalty_diff are read from the enclosing
        scope rather than the argument list — confirm that is intentional.
        """
        X, y, w = batch

        # shared and PO-specific representations
        reps_o = predict_fun_repr(params[0], X)
        reps_p0 = predict_fun_repr_small(params[1], X)
        reps_p1 = predict_fun_repr_small(params[2], X)

        # head inputs: shared part concatenated with the PO-specific part
        reps_po0 = _concatenate_representations((reps_o, reps_p0))
        reps_po1 = _concatenate_representations((reps_o, reps_p1))

        # head_0 fits controls (weight 1 - w), head_1 fits treated
        loss_0 = loss_head(params[3], (reps_po0, y, 1 - w), penalty_l2)
        loss_1 = loss_head(params[4], (reps_po1, y, w), penalty_l2)

        # orthogonal regularization on absolute first-layer row sums
        col_o = _get_absolute_rowsums(params[0][0][0])
        col_p0 = _get_absolute_rowsums(params[1][0][0])
        col_p1 = _get_absolute_rowsums(params[2][0][0])
        loss_o = penalty_orthogonal * (
            jnp.sum(col_o * col_p0 + col_o * col_p1 + col_p1 * col_p0))

        # weight decay on the three representation networks (weight
        # matrices at even indices — presumably stax-style (W, b) pairs)
        weightsq_body = sum(
            sum(
                jnp.sum(params[rep][idx][0] ** 2)
                for idx in range(0, 2 * n_layers_r, 2))
            for rep in range(3))
        weightsq_head = heads_l2_penalty(
            params[3], params[4], n_layers_out, reg_diff, penalty_l2,
            penalty_diff)

        regularization = 0.5 * (penalty_l2 * weightsq_body + weightsq_head)
        if avg_objective:
            n_batch = y.shape[0]
            # NOTE: the orthogonal term is deliberately not averaged
            return (loss_0 + loss_1) / n_batch + loss_o + regularization
        return loss_0 + loss_1 + loss_o + regularization