Example #1
0
def run_test(seqs, label_seqs, sess, preds_T, input_PHs, label_PHs, mask_PHs,
             seq_length_PH, loss_T, options):
    """Evaluate the model on a held-out split.

    Runs the graph batch-by-batch over (seqs, label_seqs), feeding the three
    hierarchical input/mask placeholder groups plus the final label
    placeholder, and accumulates per-batch losses and predictions.

    Args:
        seqs: list of input sequences (preprocessed per batch by
            mime_util.st_preprocess_hf_aux).
        label_seqs: list of binary labels aligned with seqs.
        sess: live tf.Session with initialized variables.
        preds_T: prediction tensor to fetch.
        input_PHs, label_PHs, mask_PHs: placeholder lists; only the last
            label placeholder is fed here (aux labels are train-time only).
        seq_length_PH: placeholder for per-example sequence lengths.
        loss_T: scalar loss tensor to fetch.
        options: dict of hyperparameters; 'batch_size' is used here.

    Returns:
        Tuple of (mean loss over batches, ROC-AUC, average precision score).
    """
    all_losses = []
    all_preds = []
    all_labels = []
    batch_size = options['batch_size']
    # Floor division: a trailing partial batch is intentionally dropped.
    for idx in xrange(len(label_seqs) // batch_size):
        batch_x = seqs[idx * batch_size:(idx + 1) * batch_size]
        batch_y = label_seqs[idx * batch_size:(idx + 1) * batch_size]
        inputs, _, masks, seq_length = mime_util.st_preprocess_hf_aux(
            batch_x, options)
        preds, loss = sess.run(
            [preds_T, loss_T],
            feed_dict={
                input_PHs[0]: inputs[0],
                input_PHs[1]: inputs[1],
                input_PHs[2]: inputs[2],
                mask_PHs[0]: masks[0],
                mask_PHs[1]: masks[1],
                mask_PHs[2]: masks[2],
                label_PHs[-1]: batch_y,
                seq_length_PH: seq_length,
            })
        all_losses.append(loss)
        all_preds.extend(list(preds))
        all_labels.extend(batch_y)
    auc = roc_auc_score(all_labels, all_preds)
    aucpr = average_precision_score(all_labels, all_preds)
    # NOTE: an unused accuracy computation (binarize at threshold .5) was
    # removed here — it was dead code and never returned.
    return np.mean(all_losses), auc, aucpr
Example #2
0
def train(
    input_path='',
    batch_size=100,
    num_iter=100,
    eval_period=10,
    num_eval=100,
    rnn_size=256,
    output_size=1,
    learning_rate=1e-3,
    output_path='',
    random_seed=1234,
    split_seed=1234,
    emb_activation='sigmoid',
    order_activation='sigmoid',
    visit_activation='sigmoid',
    num_dx=100,
    num_rx=100,
    num_pr=100,
    dx_emb_size=128,
    rx_emb_size=128,
    pr_emb_size=128,
    dxobj_emb_size=128,
    visit_emb_size=128,
    max_dx_per_visit=29,
    max_rx_per_dx=17,
    max_pr_per_dx=10,
    regularize=1e-3,
    aux_lambda=0.1,
    min_threshold=5,
    max_threshold=150,
    train_ratio=1.0,
    association_threshold=0.0,
):
    """Build, train, and evaluate the MiME-style hierarchical model.

    Trains with Adam on the main loss plus aux_lambda-weighted auxiliary
    losses and L2 regularization; evaluates on the validation split every
    eval_period iterations, and re-runs the test split (and checkpoints)
    whenever validation loss improves.

    Returns:
        Tuple of (best_valid_loss, best_test_loss, best_valid_auc,
        best_test_auc, best_valid_aucpr, best_test_aucpr), where the test
        metrics correspond to the best-validation-loss checkpoint.
    """
    # Capture every keyword argument as the options dict. Must stay the
    # first statement, before any local variables are introduced.
    options = locals().copy()
    input_PHs, label_PHs, mask_PHs, loss_Ts, seq_length_PH, preds_T = build_model(
        options)

    # L2 penalty over weight matrices only; vectors (biases, etc., rank < 2)
    # are deliberately excluded.
    all_vars = tf.trainable_variables()
    L2_loss = tf.constant(0.0, dtype=tf.float32)
    for var in all_vars:
        if len(var.shape) < 2:
            continue
        L2_loss += tf.reduce_sum(var**2)

    optimizer = tf.train.AdamOptimizer(learning_rate=options['learning_rate'])
    # loss_Ts[0:3] are auxiliary (dx/rx/pr) losses; loss_Ts[3] is the main
    # prediction loss.
    loss_T = options['aux_lambda'] * (loss_Ts[0] + loss_Ts[1] +
                                      loss_Ts[2]) + loss_Ts[3]
    minimize_op = optimizer.minimize(loss_T + regularize * L2_loss)
    train_seqs, train_labels, valid_seqs, valid_labels, test_seqs, test_labels = mime_util.load_data(
        options['input_path'],
        min_threshold=options['min_threshold'],
        max_threshold=options['max_threshold'],
        seed=options['split_seed'],
        train_ratio=options['train_ratio'],
        association_threshold=options['association_threshold'])

    saver = tf.train.Saver(max_to_keep=1)
    # Checkpoint prefix encodes the random and split seeds for this run.
    checkpoint_prefix = (output_path + '/r' + str(random_seed) + 's' +
                         str(split_seed) + '/model')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        best_valid_loss = 100000.0
        best_test_loss = 100000.0
        best_valid_auc = 0.0
        best_test_auc = 0.0
        best_valid_aucpr = 0.0
        best_test_aucpr = 0.0
        for train_iter in xrange(options['num_iter'] + 1):
            batch_x, batch_y = mime_util.sample_batch(train_seqs, train_labels,
                                                      options['batch_size'])
            inputs, labels, masks, seq_length = mime_util.st_preprocess_hf_aux(
                batch_x, options)
            # One optimization step; aux labels (labels[0:3]) are fed only
            # during training, never during evaluation.
            _, preds, losses = sess.run(
                [minimize_op, preds_T, loss_Ts],
                feed_dict={
                    input_PHs[0]: inputs[0],
                    input_PHs[1]: inputs[1],
                    input_PHs[2]: inputs[2],
                    mask_PHs[0]: masks[0],
                    mask_PHs[1]: masks[1],
                    mask_PHs[2]: masks[2],
                    label_PHs[0]: labels[0],
                    label_PHs[1]: labels[1],
                    label_PHs[2]: labels[2],
                    label_PHs[3]: batch_y,
                    seq_length_PH: seq_length,
                })

            if train_iter > 0 and train_iter % options['eval_period'] == 0:
                # Evaluate on validation only with the main loss loss_Ts[-1].
                valid_loss, valid_auc, valid_aucpr = run_test(
                    valid_seqs, valid_labels, sess, preds_T, input_PHs,
                    label_PHs, mask_PHs, seq_length_PH, loss_Ts[-1], options)
                if valid_loss < best_valid_loss:
                    # New best validation loss: score the test split and
                    # checkpoint the model at this iteration.
                    test_loss, test_auc, test_aucpr = run_test(
                        test_seqs, test_labels, sess, preds_T, input_PHs,
                        label_PHs, mask_PHs, seq_length_PH, loss_Ts[-1],
                        options)
                    best_valid_loss = valid_loss
                    best_valid_auc = valid_auc
                    best_valid_aucpr = valid_aucpr
                    best_test_loss = test_loss
                    best_test_auc = test_auc
                    best_test_aucpr = test_aucpr
                    # Return value of saver.save was previously bound to an
                    # unused variable; dropped.
                    saver.save(sess, checkpoint_prefix,
                               global_step=train_iter)
                print('round:%d, valid_loss:%f, valid_auc:%f, valid_aucpr:%f' %
                      (train_iter, valid_loss, valid_auc, valid_aucpr))
        return best_valid_loss, best_test_loss, best_valid_auc, best_test_auc, best_valid_aucpr, best_test_aucpr