def test_match_sklearn(self):
  """Unregularized ERM training should match sklearn's solution."""
  clf_sklearn = linear_model.LogisticRegression(penalty='none')
  clf_sklearn.fit(self.data_x, self.data_y)
  self.assertAlmostEqual(clf_sklearn.intercept_[0], -0.0271, places=3)
  self.assertAlmostEqual(clf_sklearn.coef_[0][0], 1.0987, places=3)
  self.assertAlmostEqual(clf_sklearn.coef_[0][1], -1.0216, places=3)

  model = train_proxy.LogisticReg()
  erm_loss = train_proxy.make_loss_func(
      model, self.data, erm_weight=1., bias_lamb=0.)
  init_params = train_proxy.initialize_params(
      model, mdim=self.data_x.shape[1], seed=0)
  erm_params, _ = train_proxy.train(erm_loss, init_params, lr=1., nsteps=1000)
  b, w = jax.tree_leaves(erm_params)
  self.assertAlmostEqual(w[0].item(), 1.0988, places=3)
  self.assertAlmostEqual(w[1].item(), -1.0216, places=3)
  self.assertAlmostEqual(b.item(), -0.0274, places=3)

def test_policy_bias_regularization(self):
  """A nonzero bias penalty should shift the learned parameters."""
  model = train_proxy.LogisticReg()
  mix_loss = train_proxy.make_loss_func(
      model, self.data, erm_weight=1., bias_lamb=10.)
  init_params = train_proxy.initialize_params(
      model, mdim=self.data_x.shape[1], seed=0)
  mix_params, _ = train_proxy.train(mix_loss, init_params, lr=1., nsteps=1000)
  b, w = jax.tree_leaves(mix_params)
  self.assertAlmostEqual(w[0].item(), 1.4478, places=3)
  self.assertAlmostEqual(w[1].item(), -1.2915, places=3)
  self.assertAlmostEqual(b.item(), 0.1693, places=3)

def test_l2_regularization(self):
  """A large L2 penalty should shrink the weights toward zero."""
  model = train_proxy.LogisticReg()
  erm_loss_reg = train_proxy.make_loss_func(
      model, self.data, erm_weight=1., l2_lamb=10., bias_lamb=0.)
  init_params = train_proxy.initialize_params(
      model, mdim=self.data_x.shape[1], seed=0)
  erm_params_reg, _ = train_proxy.train(
      erm_loss_reg, init_params, lr=1., nsteps=1000)
  b, w = jax.tree_leaves(erm_params_reg)
  self.assertAlmostEqual(w[0].item(), 0.0030, places=3)
  self.assertAlmostEqual(w[1].item(), -0.0028, places=3)
  self.assertAlmostEqual(b.item(), 0.0158, places=3)
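
# --- Illustration only: one plausible shape for the combined objective. -----
# The three tests above pin down make_loss_func at bias_lamb=0 (pure ERM,
# which is why test_match_sklearn can compare against sklearn's unpenalized
# LogisticRegression), bias_lamb=10 (policy-bias penalty), and l2_lamb=10
# (L2 penalty). The sketch below is a guess at how such an objective
# composes, written only to make the tested knobs concrete. It is NOT
# train_proxy's implementation: in particular, the real bias penalty depends
# on the data's 'a', 'm', 't' fields and on bias_norm, none of which are
# modeled here.
import jax
import jax.numpy as jnp


def _sketch_combined_loss(params, x, y, erm_weight=1., l2_lamb=0.,
                          bias_lamb=0.):
  """Hypothetical logistic NLL + L2 + bias-style penalty (sketch only)."""
  b, w = params
  logits = x @ w + b
  y_pm = 2. * y - 1.  # Map {0, 1} labels to {-1, +1}.
  erm = jnp.mean(jnp.logaddexp(0., -y_pm * logits))  # Binary cross-entropy.
  l2 = l2_lamb * (jnp.sum(w ** 2) + jnp.sum(b ** 2))
  # Stand-in bias penalty: gap between mean prediction and mean label.
  bias = bias_lamb * jnp.abs(jnp.mean(jax.nn.sigmoid(logits)) - jnp.mean(y))
  return erm_weight * erm + l2 + bias
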
def load_and_train():
  """Load data from file and return checkpoints from training."""
  simulation_path = f'{FLAGS.data_path}/{FLAGS.simulation_dir}'
  with gfile.GFile(f'{simulation_path}/{FLAGS.data_file}', 'r') as f:
    df = pd.read_csv(f)

  # Split this into train and validate.
  rng = np.random.default_rng(FLAGS.data_seed)
  users = np.unique(df['user'])
  users = rng.permutation(users)
  n_users = users.shape[0]
  n_train_users = int(n_users / 2)
  users_train = users[:n_train_users]
  users_val = users[n_train_users:]
  assert users_val.shape[0] + users_train.shape[0] == n_users
  df_tr = df.query('user in @users_train').copy()
  df_val = df.query('user in @users_val').copy()

  a_tr = df_tr['rec'].to_numpy()
  m_tr = df_tr[['diversity', 'rating']].to_numpy()
  y_tr = df_tr['ltr'].to_numpy()
  t_tr = np.ones_like(a_tr)
  a_val = df_val['rec'].to_numpy()
  m_val = df_val[['diversity', 'rating']].to_numpy()
  y_val = df_val['ltr'].to_numpy()
  t_val = np.ones_like(a_val)

  model = train_proxy.LogisticReg()
  data_tr = {
      'a': a_tr,
      'm': m_tr,
      'y': y_tr,
      't': t_tr,
  }
  data_val = {
      'a': a_val,
      'm': m_val,
      'y': y_val,
      't': t_val,
  }
  init_params = train_proxy.initialize_params(model, mdim=2, seed=FLAGS.seed)
  loss_tr = train_proxy.make_loss_func(
      model,
      data_tr,
      erm_weight=FLAGS.erm_weight,
      bias_lamb=FLAGS.bias_lamb,
      bias_norm=FLAGS.bias_norm)
  loss_val = train_proxy.make_loss_func(
      model,
      data_val,
      erm_weight=FLAGS.erm_weight,
      bias_lamb=FLAGS.bias_lamb,
      bias_norm=FLAGS.bias_norm)
  _, checkpoints = train_proxy.train(
      loss_tr,
      init_params,
      validation_loss=loss_val,
      lr=FLAGS.learning_rate,
      nsteps=FLAGS.nsteps,
      tol=FLAGS.tol,
      verbose=True,
      log=True)
  return checkpoints
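
# --- Usage sketch (assumptions flagged inline). ------------------------------
# load_and_train is presumably driven by an absl app entry point; a minimal
# driver might look like the following. The output filename 'checkpoints.pkl'
# and the use of pickle are illustrative assumptions, not the real pipeline's
# output format.
import pickle

from absl import app


def main(argv):
  del argv  # Unused.
  checkpoints = load_and_train()
  out_path = f'{FLAGS.data_path}/{FLAGS.simulation_dir}/checkpoints.pkl'
  with gfile.GFile(out_path, 'wb') as f:  # Hypothetical destination.
    pickle.dump(checkpoints, f)


if __name__ == '__main__':
  app.run(main)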