def train(CFR, sess, train_step, D, I_valid, D_test, logfile, i_exp):
    """ Trains a CFR model on supplied data.

    Args:
        CFR: the model object; this function reads its placeholders
            (x, t, y_, do_in, do_out, r_alpha, r_lambda, p_t, w_proj) and
            its tensors (tot_loss, pred_loss, imb_dist, h_rep, h_rep_norm,
            output, weights_in, projection).
        sess: an open TensorFlow session in which the model's graph lives.
        train_step: the optimizer op to run once per iteration.
        D: training data dict with keys 'x', 't', 'yf' (2-D arrays indexed
            as [rows, :]), 'HAVE_TRUTH' (bool), and — only when HAVE_TRUTH
            is true — 'ycf' (counterfactual outcomes).
        I_valid: row indices of D held out for validation.
        D_test: optional test-set dict with the same 'x'/'t' keys, or None.
        logfile: path/handle passed through to the module-level log().
        i_exp: 1-based experiment index (representations are only saved
            for experiment 1).

    Returns:
        (losses, preds_train, preds_test, reps, reps_test) where losses is
        a list of [obj, f_err, cf_err, imb, val_f, val_imb, val_obj] rows,
        preds_* hold factual/counterfactual outputs concatenated on axis 1,
        and reps/reps_test hold saved representations (possibly empty).
    """

    ''' Train/validation split '''
    n = D['x'].shape[0]
    # NOTE(review): set difference does not guarantee a stable ordering of
    # I_train across runs/interpreters — confirm downstream code does not
    # rely on row order of the training split.
    I = range(n); I_train = list(set(I)-set(I_valid))
    n_train = len(I_train)

    ''' Compute treatment probability'''
    # Fraction of treated units in the training split; fed as CFR.p_t below.
    p_treated = np.mean(D['t'][I_train,:])

    ''' Set up loss feed_dicts'''
    # Factual losses are evaluated with dropout disabled (do_in/do_out = 1.0).
    dict_factual = {CFR.x: D['x'][I_train,:], CFR.t: D['t'][I_train,:], CFR.y_: D['yf'][I_train,:], \
      CFR.do_in: 1.0, CFR.do_out: 1.0, CFR.r_alpha: FLAGS.p_alpha, \
      CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated}

    if FLAGS.val_part > 0:
        dict_valid = {CFR.x: D['x'][I_valid,:], CFR.t: D['t'][I_valid,:], CFR.y_: D['yf'][I_valid,:], \
          CFR.do_in: 1.0, CFR.do_out: 1.0, CFR.r_alpha: FLAGS.p_alpha, \
          CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated}

    if D['HAVE_TRUTH']:
        # Counterfactual loss: flip the treatment indicator (1-t) and compare
        # against the known counterfactual outcomes 'ycf'.
        dict_cfactual = {CFR.x: D['x'][I_train,:], CFR.t: 1-D['t'][I_train,:], CFR.y_: D['ycf'][I_train,:], \
          CFR.do_in: 1.0, CFR.do_out: 1.0}

    ''' Initialize TensorFlow variables '''
    sess.run(tf.global_variables_initializer())

    ''' Set up for storing predictions '''
    preds_train = []
    preds_test = []

    ''' Compute losses '''
    # Record losses once at iteration 0, before any gradient step.
    losses = []
    obj_loss, f_error, imb_err = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],\
      feed_dict=dict_factual)

    cf_error = np.nan
    if D['HAVE_TRUTH']:
        cf_error = sess.run(CFR.pred_loss, feed_dict=dict_cfactual)

    valid_obj = np.nan; valid_imb = np.nan; valid_f_error = np.nan;
    if FLAGS.val_part > 0:
        valid_obj, valid_f_error, valid_imb = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],\
          feed_dict=dict_valid)

    losses.append([obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj])

    # Once the objective goes NaN we stop taking gradient steps but keep
    # looping so the loss/prediction bookkeeping stays aligned.
    objnan = False

    reps = []
    reps_test = []

    ''' Train for multiple iterations '''
    for i in range(FLAGS.iterations):

        ''' Fetch sample '''
        # NOTE(review): this rebinds I (previously range(n) above) to the
        # mini-batch indices — intentional but easy to misread.
        I = random.sample(range(0, n_train), FLAGS.batch_size)
        x_batch = D['x'][I_train,:][I,:]
        t_batch = D['t'][I_train,:][I]
        y_batch = D['yf'][I_train,:][I]

        if __DEBUG__:
            # __DEBUG__ is presumably a module-level flag — not visible here.
            # NOTE(review): cfr.pop_dist(...) builds NEW graph ops on every
            # debug iteration in TF1, growing the graph over time — confirm.
            M = sess.run(cfr.pop_dist(CFR.x, CFR.t), feed_dict={CFR.x: x_batch,
                CFR.t: t_batch})
            log(logfile, 'Median: %.4g, Mean: %.4f, Max: %.4f' % (np.median(M.tolist()), np.mean(M.tolist()), np.amax(M.tolist())))

        ''' Do one step of gradient descent '''
        if not objnan:
            sess.run(train_step, feed_dict={CFR.x: x_batch, CFR.t: t_batch, \
                CFR.y_: y_batch, CFR.do_in: FLAGS.dropout_in, CFR.do_out: FLAGS.dropout_out, \
                CFR.r_alpha: FLAGS.p_alpha, CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated})

        ''' Project variable selection weights '''
        # With variable selection enabled, project the input weights back
        # onto the simplex after each step (simplex_project is module-level).
        if FLAGS.varsel:
            wip = simplex_project(sess.run(CFR.weights_in[0]), 1)
            sess.run(CFR.projection, feed_dict={CFR.w_proj: wip})

        ''' Compute loss every N iterations '''
        if i % FLAGS.output_delay == 0 or i==FLAGS.iterations-1:
            obj_loss,f_error,imb_err = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],
                feed_dict=dict_factual)

            rep = sess.run(CFR.h_rep_norm, feed_dict={CFR.x: D['x'], CFR.do_in: 1.0})
            # NOTE(review): rep_norm is computed but never used or logged
            # below — dead code, or a logging line was lost.
            rep_norm = np.mean(np.sqrt(np.sum(np.square(rep), 1)))

            cf_error = np.nan
            if D['HAVE_TRUTH']:
                cf_error = sess.run(CFR.pred_loss, feed_dict=dict_cfactual)

            valid_obj = np.nan; valid_imb = np.nan; valid_f_error = np.nan;
            if FLAGS.val_part > 0:
                valid_obj, valid_f_error, valid_imb = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],
                    feed_dict=dict_valid)

            losses.append([obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj])
            loss_str = str(i) + '\tObj: %.3f,\tF: %.3f,\tCf: %.3f,\tImb: %.2g,\tVal: %.3f,\tValImb: %.2g,\tValObj: %.2f' \
                % (obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj)

            if FLAGS.loss == 'log':
                # Log-loss mode: report batch classification accuracy using a
                # 0.5 threshold on the model output.
                y_pred = sess.run(CFR.output, feed_dict={CFR.x: x_batch, \
                    CFR.t: t_batch, CFR.do_in: 1.0, CFR.do_out: 1.0})
                y_pred = 1.0*(y_pred > 0.5)
                acc = 100*(1 - np.mean(np.abs(y_batch - y_pred)))
                loss_str += ',\tAcc: %.2f%%' % acc

            log(logfile, loss_str)

            if np.isnan(obj_loss):
                log(logfile,'Experiment %d: Objective is NaN. Skipping.'
                    % i_exp)
                objnan = True

        ''' Compute predictions every M iterations '''
        if (FLAGS.pred_output_delay > 0 and i % FLAGS.pred_output_delay == 0) or i==FLAGS.iterations-1:

            # Factual and counterfactual (t flipped) predictions over ALL of
            # D, concatenated column-wise: [y_pred_f | y_pred_cf].
            y_pred_f = sess.run(CFR.output, feed_dict={CFR.x: D['x'], \
                CFR.t: D['t'], CFR.do_in: 1.0, CFR.do_out: 1.0})
            y_pred_cf = sess.run(CFR.output, feed_dict={CFR.x: D['x'], \
                CFR.t: 1-D['t'], CFR.do_in: 1.0, CFR.do_out: 1.0})
            preds_train.append(np.concatenate((y_pred_f, y_pred_cf),axis=1))

            if D_test is not None:
                y_pred_f_test = sess.run(CFR.output, feed_dict={CFR.x: D_test['x'], \
                    CFR.t: D_test['t'], CFR.do_in: 1.0, CFR.do_out: 1.0})
                y_pred_cf_test = sess.run(CFR.output, feed_dict={CFR.x: D_test['x'], \
                    CFR.t: 1-D_test['t'], CFR.do_in: 1.0, CFR.do_out: 1.0})
                preds_test.append(np.concatenate((y_pred_f_test, y_pred_cf_test),axis=1))

            if FLAGS.save_rep and i_exp == 1:
                # Representations are saved only for the first experiment.
                # sess.run([CFR.h_rep]) returns a one-element LIST, so each
                # entry of reps/reps_test is [array], not a bare array.
                reps_i = sess.run([CFR.h_rep], feed_dict={CFR.x: D['x'], \
                    CFR.do_in: 1.0, CFR.do_out: 0.0})
                reps.append(reps_i)

                if D_test is not None:
                    reps_test_i = sess.run([CFR.h_rep], feed_dict={CFR.x: D_test['x'], \
                        CFR.do_in: 1.0, CFR.do_out: 0.0})
                    reps_test.append(reps_test_i)

    return losses, preds_train, preds_test, reps, reps_test
def run_descriptive_stats(self, x_batch, t_batch):
    """Evaluate population-distance statistics for one mini-batch.

    Feeds x_batch / t_batch into the model's placeholders and runs the
    cfr.pop_dist op, returning whatever array of distance statistics
    that op produces (shape/semantics defined by cfr.pop_dist — not
    visible here).

    Args:
        x_batch: covariate mini-batch fed to self.CFR.x.
        t_batch: treatment-indicator mini-batch fed to self.CFR.t.

    Returns:
        The evaluated pop_dist statistics for the batch.
    """
    # Bug fix: the original rebuilt cfr.pop_dist(...) on every call. In
    # TF1 (graph mode, as used throughout this file) each call adds new
    # nodes to the graph, so repeated calls grow the graph without bound.
    # Build the op once and cache it on the instance instead.
    if not hasattr(self, '_pop_dist_op'):
        self._pop_dist_op = cfr.pop_dist(self.CFR.x, self.CFR.t)
    m_statistics = self.sess.run(
        self._pop_dist_op,
        feed_dict={self.CFR.x: x_batch, self.CFR.t: t_batch})
    return m_statistics