Exemple #1
0
    def test_match_sklearn(self):
        """Unregularized ERM training should reproduce sklearn's logistic fit.

        First pins the coefficients of an unpenalized sklearn
        ``LogisticRegression`` fit on ``self.data_x`` / ``self.data_y``. It then
        trains ``train_proxy.LogisticReg`` on ``self.data`` with
        ``erm_weight=1.`` and ``bias_lamb=0.``, and checks that the learned
        weights and bias match the expected values to 3 decimal places.
        """
        # penalty=None disables regularization. The old spelling
        # penalty='none' was deprecated in scikit-learn 1.2 and removed in 1.4.
        clf_sklearn = linear_model.LogisticRegression(penalty=None)
        clf_sklearn.fit(self.data_x, self.data_y)

        self.assertAlmostEqual(clf_sklearn.intercept_[0], -0.0271, places=3)
        self.assertAlmostEqual(clf_sklearn.coef_[0][0], 1.0987, places=3)
        self.assertAlmostEqual(clf_sklearn.coef_[0][1], -1.0216, places=3)

        model = train_proxy.LogisticReg()

        # bias_lamb=0. switches off the bias penalty. Presumably this leaves a
        # pure ERM loss comparable to sklearn's; confirm in make_loss_func.
        erm_loss = train_proxy.make_loss_func(model,
                                              self.data,
                                              erm_weight=1.,
                                              bias_lamb=0.)

        init_params = train_proxy.initialize_params(model,
                                                    mdim=self.data_x.shape[1],
                                                    seed=0)

        erm_params, _ = train_proxy.train(erm_loss,
                                          init_params,
                                          lr=1.,
                                          nsteps=1000)

        # jax.tree_util.tree_leaves is the stable API. The top-level
        # jax.tree_leaves alias is deprecated and removed in recent JAX.
        # NOTE(review): assumes the params pytree flattens as (bias, weights);
        # confirm against train_proxy.initialize_params.
        b, w = jax.tree_util.tree_leaves(erm_params)
        self.assertAlmostEqual(w[0].item(), 1.0988, places=3)
        self.assertAlmostEqual(w[1].item(), -1.0216, places=3)
        self.assertAlmostEqual(b.item(), -0.0274, places=3)
Exemple #2
0
    def test_policy_bias_regularization(self):
        """Training with a policy-bias penalty should converge to pinned values.

        Trains ``train_proxy.LogisticReg`` on ``self.data`` with
        ``erm_weight=1.`` and ``bias_lamb=10.``. It then checks the learned
        weights and bias against regression values to 3 decimal places.
        """
        model = train_proxy.LogisticReg()

        mix_loss = train_proxy.make_loss_func(model,
                                              self.data,
                                              erm_weight=1.,
                                              bias_lamb=10.)

        init_params = train_proxy.initialize_params(model,
                                                    mdim=self.data_x.shape[1],
                                                    seed=0)

        mix_params, _ = train_proxy.train(mix_loss,
                                          init_params,
                                          lr=1.,
                                          nsteps=1000)

        # jax.tree_util.tree_leaves is the stable API. The top-level
        # jax.tree_leaves alias is deprecated and removed in recent JAX.
        # NOTE(review): assumes the params pytree flattens as (bias, weights).
        b, w = jax.tree_util.tree_leaves(mix_params)
        self.assertAlmostEqual(w[0].item(), 1.4478, places=3)
        self.assertAlmostEqual(w[1].item(), -1.2915, places=3)
        self.assertAlmostEqual(b.item(), 0.1693, places=3)
Exemple #3
0
    def test_l2_regularization(self):
        """A strong L2 penalty should shrink the learned weights toward zero.

        Trains ``train_proxy.LogisticReg`` on ``self.data`` with
        ``l2_lamb=10.`` and ``bias_lamb=0.``. It then checks the learned
        weights and bias against regression values to 3 decimal places.
        """
        model = train_proxy.LogisticReg()

        erm_loss_reg = train_proxy.make_loss_func(model,
                                                  self.data,
                                                  erm_weight=1.,
                                                  l2_lamb=10.,
                                                  bias_lamb=0.)

        init_params = train_proxy.initialize_params(model,
                                                    mdim=self.data_x.shape[1],
                                                    seed=0)

        erm_params_reg, _ = train_proxy.train(erm_loss_reg,
                                              init_params,
                                              lr=1.,
                                              nsteps=1000)

        # jax.tree_util.tree_leaves is the stable API. The top-level
        # jax.tree_leaves alias is deprecated and removed in recent JAX.
        # NOTE(review): assumes the params pytree flattens as (bias, weights).
        b, w = jax.tree_util.tree_leaves(erm_params_reg)
        self.assertAlmostEqual(w[0].item(), 0.0030, places=3)
        self.assertAlmostEqual(w[1].item(), -0.0028, places=3)
        self.assertAlmostEqual(b.item(), 0.0158, places=3)
# Simulation CSV columns used as the proxy features ``m``. The model's feature
# dimension (``mdim``) is derived from this, so the two cannot drift apart.
_PROXY_FEATURE_COLUMNS = ('diversity', 'rating')


def _split_users(df, seed):
    """Shuffle the unique users of ``df`` and split them into two halves.

    Args:
      df: DataFrame with a ``'user'`` column.
      seed: Seed passed to ``np.random.default_rng``.

    Returns:
      ``(users_train, users_val)`` as disjoint arrays. The training half gets
      ``n_users // 2`` users and validation gets the remainder.
    """
    rng = np.random.default_rng(seed)
    users = rng.permutation(np.unique(df['user']))
    n_train_users = users.shape[0] // 2
    return users[:n_train_users], users[n_train_users:]


def _frame_to_data(frame):
    """Convert simulation rows into the data dict passed to ``train_proxy``.

    Keys: ``'a'`` from column ``'rec'``, ``'m'`` from the proxy feature
    columns, ``'y'`` from column ``'ltr'``, and ``'t'`` as ones shaped (and
    typed) like ``'a'``.
    """
    a = frame['rec'].to_numpy()
    return {
        'a': a,
        'm': frame[list(_PROXY_FEATURE_COLUMNS)].to_numpy(),
        'y': frame['ltr'].to_numpy(),
        # NOTE(review): every row gets t=1; confirm what train_proxy uses t for.
        't': np.ones_like(a),
    }


def load_and_train():
    """Load simulation data from file and return checkpoints from training.

    Reads ``{FLAGS.data_path}/{FLAGS.simulation_dir}/{FLAGS.data_file}`` as
    CSV. It randomly assigns half of the users (seeded by ``FLAGS.data_seed``)
    to training and the rest to validation. It then trains
    ``train_proxy.LogisticReg`` on the training rows while tracking the loss
    on the validation rows.

    Returns:
      The checkpoints returned by ``train_proxy.train`` (called with
      ``log=True``). Their structure is defined by ``train_proxy``.
    """
    simulation_path = f'{FLAGS.data_path}/{FLAGS.simulation_dir}'
    with gfile.GFile(f'{simulation_path}/{FLAGS.data_file}', 'r') as f:
        df = pd.read_csv(f)

    # Split by user rather than by row, so no user appears in both splits.
    # DataFrame.query resolves @-names from this function's local scope.
    users_train, users_val = _split_users(df, FLAGS.data_seed)
    data_tr = _frame_to_data(df.query('user in @users_train'))
    data_val = _frame_to_data(df.query('user in @users_val'))

    model = train_proxy.LogisticReg()
    init_params = train_proxy.initialize_params(
        model, mdim=len(_PROXY_FEATURE_COLUMNS), seed=FLAGS.seed)

    def make_loss(data):
        # Train and validation losses share identical hyperparameters.
        return train_proxy.make_loss_func(model,
                                          data,
                                          erm_weight=FLAGS.erm_weight,
                                          bias_lamb=FLAGS.bias_lamb,
                                          bias_norm=FLAGS.bias_norm)

    _, checkpoints = train_proxy.train(make_loss(data_tr),
                                       init_params,
                                       validation_loss=make_loss(data_val),
                                       lr=FLAGS.learning_rate,
                                       nsteps=FLAGS.nsteps,
                                       tol=FLAGS.tol,
                                       verbose=True,
                                       log=True)

    return checkpoints