def run_test(seqs, label_seqs, sess, preds_T, input_PHs, label_PHs, mask_PHs,
             seq_length_PH, loss_T, options):
    """Evaluate the model on a labelled set, one full batch at a time.

    Only ``len(label_seqs) // batch_size`` full batches are evaluated; any
    trailing examples that do not fill a whole batch are skipped.

    Args:
        seqs: Input sequences; sliced per batch and handed to
            ``mime_util.st_preprocess_hf_aux``.
        label_seqs: Labels aligned with ``seqs``; each batch slice is fed to
            ``label_PHs[-1]``.
        sess: Session whose ``run`` method fetches ``preds_T`` and ``loss_T``.
        preds_T: Prediction tensor to fetch.
        input_PHs: Input placeholders; indices 0, 1 and 2 are fed.
        label_PHs: Label placeholders; only the last one is fed here.
        mask_PHs: Mask placeholders; indices 0, 1 and 2 are fed.
        seq_length_PH: Placeholder fed with the sequence lengths returned by
            the preprocessing helper.
        loss_T: Loss tensor to fetch.
        options: Dict of settings; ``'batch_size'`` is read here and the whole
            dict is forwarded to the preprocessing helper.

    Returns:
        Tuple ``(mean_loss, auc, aucpr)``: the mean of the per-batch losses,
        ``roc_auc_score`` and ``average_precision_score`` over all evaluated
        examples.
    """
    batch_size = options['batch_size']
    # Floor division keeps the batch count an integer on Python 2 and 3 alike.
    num_batches = len(label_seqs) // batch_size

    all_losses = []
    all_preds = []
    all_labels = []
    for idx in range(num_batches):
        start = idx * batch_size
        stop = start + batch_size
        batch_x = seqs[start:stop]
        batch_y = label_seqs[start:stop]

        # The second return value (auxiliary labels) is not needed when
        # evaluating only the final loss.
        inputs, _, masks, seq_length = mime_util.st_preprocess_hf_aux(
            batch_x, options)
        feed_dict = {
            input_PHs[0]: inputs[0],
            input_PHs[1]: inputs[1],
            input_PHs[2]: inputs[2],
            mask_PHs[0]: masks[0],
            mask_PHs[1]: masks[1],
            mask_PHs[2]: masks[2],
            label_PHs[-1]: batch_y,
            seq_length_PH: seq_length,
        }
        preds, loss = sess.run([preds_T, loss_T], feed_dict=feed_dict)

        all_losses.append(loss)
        all_preds.extend(preds)
        all_labels.extend(batch_y)

    auc = roc_auc_score(all_labels, all_preds)
    aucpr = average_precision_score(all_labels, all_preds)
    return np.mean(all_losses), auc, aucpr
def train(
        input_path='',
        batch_size=100,
        num_iter=100,
        eval_period=10,
        num_eval=100,
        rnn_size=256,
        output_size=1,
        learning_rate=1e-3,
        output_path='',
        random_seed=1234,
        split_seed=1234,
        emb_activation='sigmoid',
        order_activation='sigmoid',
        visit_activation='sigmoid',
        num_dx=100,
        num_rx=100,
        num_pr=100,
        dx_emb_size=128,
        rx_emb_size=128,
        pr_emb_size=128,
        dxobj_emb_size=128,
        visit_emb_size=128,
        max_dx_per_visit=29,
        max_rx_per_dx=17,
        max_pr_per_dx=10,
        regularize=1e-3,
        aux_lambda=0.1,
        min_threshold=5,
        max_threshold=150,
        train_ratio=1.0,
        association_threshold=0.0,
):
    """Build the model, train it, and track the best validation checkpoint.

    Every keyword argument is snapshotted into an ``options`` dict that is
    passed to ``build_model`` and ``mime_util.st_preprocess_hf_aux``. The
    arguments read directly in this function are:

    Args:
        input_path, min_threshold, max_threshold, split_seed, train_ratio,
        association_threshold: Forwarded to ``mime_util.load_data`` (with
            ``split_seed`` passed as ``seed``).
        batch_size: Size of each sampled training batch; also read by
            ``run_test`` through ``options``.
        num_iter: Training runs for ``num_iter + 1`` iterations.
        eval_period: Validation runs on every positive iteration that is a
            multiple of this value.
        learning_rate: Learning rate of the Adam optimizer.
        aux_lambda: Weight applied to the sum of the first three losses
            returned by ``build_model``; the fourth loss is added unweighted.
        regularize: Weight of the L2 penalty added to the optimized loss.
        output_path, random_seed, split_seed: Combined into the checkpoint
            prefix ``<output_path>/r<random_seed>s<split_seed>/model``.
            Within this function ``random_seed`` is used only for that path.

    The remaining arguments are not read here; they are only available to
    the helpers through ``options``.

    Returns:
        Tuple ``(best_valid_loss, best_test_loss, best_valid_auc,
        best_test_auc, best_valid_aucpr, best_test_aucpr)`` taken at the
        evaluation round with the lowest validation loss.
    """
    # Must remain the first statement so that locals() holds exactly the
    # keyword arguments and nothing else.
    options = locals().copy()

    (input_PHs, label_PHs, mask_PHs, loss_Ts, seq_length_PH,
     preds_T) = build_model(options)

    # L2 penalty over trainable variables with at least two dimensions;
    # lower-rank variables are left unregularized.
    L2_loss = tf.constant(0.0, dtype=tf.float32)
    for var in tf.trainable_variables():
        if len(var.shape) < 2:
            continue
        L2_loss += tf.reduce_sum(var**2)

    optimizer = tf.train.AdamOptimizer(learning_rate=options['learning_rate'])
    loss_T = (options['aux_lambda'] * (loss_Ts[0] + loss_Ts[1] + loss_Ts[2])
              + loss_Ts[3])
    minimize_op = optimizer.minimize(loss_T + regularize * L2_loss)

    (train_seqs, train_labels, valid_seqs, valid_labels, test_seqs,
     test_labels) = mime_util.load_data(
         options['input_path'],
         min_threshold=options['min_threshold'],
         max_threshold=options['max_threshold'],
         seed=options['split_seed'],
         train_ratio=options['train_ratio'],
         association_threshold=options['association_threshold'])

    # Loop-invariant checkpoint prefix, built once instead of on every save.
    checkpoint_prefix = (output_path + '/r' + str(random_seed) + 's' +
                         str(split_seed) + '/model')

    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        best_valid_loss = 100000.0
        best_test_loss = 100000.0
        best_valid_auc = 0.0
        best_test_auc = 0.0
        best_valid_aucpr = 0.0
        best_test_aucpr = 0.0

        for train_iter in range(options['num_iter'] + 1):
            batch_x, batch_y = mime_util.sample_batch(
                train_seqs, train_labels, options['batch_size'])
            inputs, labels, masks, seq_length = (
                mime_util.st_preprocess_hf_aux(batch_x, options))

            # Only the update op is fetched: the predictions and losses of
            # the training step were never used.
            sess.run(
                minimize_op,
                feed_dict={
                    input_PHs[0]: inputs[0],
                    input_PHs[1]: inputs[1],
                    input_PHs[2]: inputs[2],
                    mask_PHs[0]: masks[0],
                    mask_PHs[1]: masks[1],
                    mask_PHs[2]: masks[2],
                    label_PHs[0]: labels[0],
                    label_PHs[1]: labels[1],
                    label_PHs[2]: labels[2],
                    label_PHs[3]: batch_y,
                    seq_length_PH: seq_length,
                })

            # Evaluate only on positive multiples of eval_period.
            if train_iter == 0 or train_iter % options['eval_period'] != 0:
                continue

            valid_loss, valid_auc, valid_aucpr = run_test(
                valid_seqs, valid_labels, sess, preds_T, input_PHs, label_PHs,
                mask_PHs, seq_length_PH, loss_Ts[-1], options)

            if valid_loss < best_valid_loss:
                # The test set is scored only when validation improves, so
                # the reported test metrics belong to the selected model.
                test_loss, test_auc, test_aucpr = run_test(
                    test_seqs, test_labels, sess, preds_T, input_PHs,
                    label_PHs, mask_PHs, seq_length_PH, loss_Ts[-1], options)
                best_valid_loss = valid_loss
                best_valid_auc = valid_auc
                best_valid_aucpr = valid_aucpr
                best_test_loss = test_loss
                best_test_auc = test_auc
                best_test_aucpr = test_aucpr
                saver.save(sess, checkpoint_prefix, global_step=train_iter)

            print('round:%d, valid_loss:%f, valid_auc:%f, valid_aucpr:%f' %
                  (train_iter, valid_loss, valid_auc, valid_aucpr))

    return (best_valid_loss, best_test_loss, best_valid_auc, best_test_auc,
            best_valid_aucpr, best_test_aucpr)