示例#1
0
         #                                time.time() - t0))
 """Evaluation"""
 # When load_model is falsy, restore the checkpoint at save_model/<task>-<i>.
 # NOTE(review): `task` and `i` come from the enclosing scope (not visible
 # here) and must match the names used when the checkpoint was saved.
 if not load_model:
     saver.restore(sess, save_model + '/{}-{}'.format(task, i))
 test_datas = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):],
               t_ph: tte, y_ph: yte}
 # One FGSM-style adversarial step on the test inputs: shift xte by
 # adv_scale * sign(gradient w.r.t. x_ph), evaluated on the clean test batch.
 # NOTE(review): build_loss_and_gradients presumably adds new ops to the graph
 # every time this code runs, and `+=` updates xte in place (if it is a NumPy
 # array), so the caller's test matrix -- including the binary-feature
 # columns -- is perturbed as well. Confirm both are intended.
 loss, grads_and_vars = inference.build_loss_and_gradients([x_ph])
 grad, varib = grads_and_vars[0]
 test_grad = sess.run(grad, feed_dict=test_datas)
 xte += adv_scale * np.sign(test_grad)
 # Feeds for predicting on the perturbed inputs under the treatment vectors
 # tr1t / tr0t (presumably all-ones / all-zeros -- confirm).
 f1t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr1t}
 f0t = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: tr0t}
 """Data elimination"""
 if drop_type != "random":
     # Also runs when drop_type == "kl".
     # Per-sample predictive variance on the test set; np.var(..., axis=0)
     # treats axis 0 of the returned samples as the sample axis (presumably
     # L=100 draws produced by get_y0_y1 -- confirm). Assuming tte is a 0/1
     # treatment indicator:
     #   cpvr = variance of the outcome under the arm NOT received,
     #   fpvr = variance of the outcome under the arm received.
     pred_y_0_samples, pred_y_1_samples = get_y0_y1(
         sess, y_post, f0t, f1t, shape=yte.shape, L=100, verbose=False,
         task=task, get_sample=True)
     loss_cpvr_0 = (1-tte) * np.var(pred_y_1_samples, axis=0)
     loss_cpvr_1 = tte * np.var(pred_y_0_samples, axis=0)
     # squeeze() drops size-1 axes, e.g. (m, 1) -> (m,); shapes unconfirmed.
     loss_cpvr = (loss_cpvr_0 + loss_cpvr_1).squeeze()
     loss_fpvr_0 = tte * np.var(pred_y_1_samples, axis=0)
     loss_fpvr_1 = (1-tte) * np.var(pred_y_0_samples, axis=0)
     loss_fpvr = (loss_fpvr_0 + loss_fpvr_1).squeeze()
 if drop_type == "kl":
     # Per-sample KL( N(muq_t, sigmaq_t^2) || N(mu_z, sigma_z^2) ), averaged
     # (reduce_mean) over axis 1, for each treatment arm. The two arms are
     # summed and the samples are ordered by ascending KL. Closed form:
     #   0.5 * [((muq - mu_z)^2 + sigmaq^2) / sigma_z^2
     #          + log sigma_z^2 - log sigmaq^2 - 1]
     # Both fractions divide by sigma_z**2. The log term (like every use of
     # sigmaq) treats sigma_z as a standard deviation, so dividing by sigma_z
     # alone yields an inconsistent, incorrect divergence.
     # NOTE(review): these ops are also added to the graph on every run.
     alpha = 1e-8  # keeps the log arguments strictly positive
     KL_t0t = 0.5 * tf.reduce_mean(
         tf.divide(tf.square(muq_t0 - mu_z), tf.square(sigma_z))
         + tf.divide(tf.square(sigmaq_t0), tf.square(sigma_z))
         + tf.log(tf.square(sigma_z) + alpha)
         - tf.log(tf.square(sigmaq_t0) + alpha * tf.ones_like(mu_z))
         - tf.ones_like(mu_z),
         axis=1)
     KL_t1t = 0.5 * tf.reduce_mean(
         tf.divide(tf.square(muq_t1 - mu_z), tf.square(sigma_z))
         + tf.divide(tf.square(sigmaq_t1), tf.square(sigma_z))
         + tf.log(tf.square(sigma_z) + alpha)
         - tf.log(tf.square(sigmaq_t1) + alpha * tf.ones_like(mu_z))
         - tf.ones_like(mu_z),
         axis=1)
     KL_t0t = sess.run(KL_t0t, feed_dict=f0t)
     KL_t1t = sess.run(KL_t1t, feed_dict=f1t)
     KLt = KL_t0t + KL_t1t
     indices = np.argsort(KLt)  # sample indices in order of ascending KL
示例#2
0
                                         t_ph: tva,
                                         y_ph: yva
                                     })
                # Early stopping: whenever the validation bound is >= the best
                # value seen so far, record it and checkpoint the model.
                if best_logpvalid <= logpvalid:
                    print('Improved Validation Bound, Old: {:0.3f}, New: {:0.3f}'.format(
                        best_logpvalid, logpvalid))
                    best_logpvalid = logpvalid
                    saver.save(sess, 'models/ihdp')

            if epoch % args.print_every == 0:
                y0, y1 = get_y0_y1(sess,
                                   y_post,
                                   f0,
                                   f1,
                                   shape=yalltr.shape,
                                   L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_train = evaluator_train.calc_stats(y1, y0)
                rmses_train = evaluator_train.y_errors(y0, y1)

                y0, y1 = get_y0_y1(sess,
                                   y_post,
                                   f0t,
                                   f1t,
                                   shape=yte.shape,
                                   L=1)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                score_test = evaluator_test.calc_stats(y1, y0)
示例#3
0
            if epoch % args.earl == 0 or epoch == (n_epoch - 1):
                # Validation bound on the held-out pairs. For each side (i and j)
                # the first len(binfeats) columns feed the *_bin placeholder and
                # the remaining columns feed the *_cont placeholder, alongside
                # that side's treatments and outcomes.
                logpvalid = sess.run(
                    logp_valid,
                    feed_dict={
                        xi_ph_bin: xiva[:, 0:len(binfeats)],
                        xi_ph_cont: xiva[:, len(binfeats):],
                        ti_ph: tiva,
                        yi_ph: yiva,
                        xj_ph_bin: xjva[:, 0:len(binfeats)],
                        xj_ph_cont: xjva[:, len(binfeats):],
                        tj_ph: tjva,
                        yj_ph: yjva,
                    },
                )
                # Early stopping: checkpoint whenever the bound is >= the best so far.
                if best_logpvalid <= logpvalid:
                    print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(
                        best_logpvalid, logpvalid))
                    best_logpvalid = logpvalid
                    saver.save(sess, 'models/pair-ihdp')

            if epoch % args.print_every == 0:
                # predict for the whole training
                yi0, yi1 = get_y0_y1(sess, yi_post, fi0, fi1, shape=yialltr.shape, L=1)
                yi0, yi1 = yi0 * yis + yim, yi1 * yis + yim

                yj0, yj1 = get_y0_y1(sess, yj_post, fj0, fj1, shape=yjalltr.shape, L=1)
                yj0, yj1 = yj0 * yjs + yjm, yj1 * yjs + yjm

                # I need to figure out the average of each node, so we need to retrieve which nodes are in a certain pair.
                # We have the list of node idx, so we create a data structure s.t. we can group them by the pos and calculate average

                avg_yi0, avg_yi1 = average_y(yi0,nodei_map_alltr,pos_i_alltr), average_y(yi1,nodei_map_alltr,pos_i_alltr)
                avg_yj0, avg_yj1 = average_y(yj0,nodej_map_alltr,pos_j_alltr), average_y(yj1,nodej_map_alltr,pos_j_alltr)

                score_train_i = evaluatori_train.calc_stats(avg_yi1, avg_yi0)
                rmses_train_i = evaluatori_train.y_errors(avg_yi0, avg_yi1)

                score_train_j = evaluatorj_train.calc_stats(avg_yj1, avg_yj0)