def train(hidden_size, learning_rate, l2_regularization):

    # train_set = np.load('../../Trajectory_generate/dataset_file/HF_train_.npy').reshape(-1, 6, 30)
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_validate_.npy').reshape(-1, 6, 30)
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_test_.npy').reshape(-1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    train_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy').reshape(-1, 13, 40)
    test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 10

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64

    epochs = 50

    # hidden_size = 2 ** (int(hidden_size))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization

    print('previous_visit---{}---predicted_visit----{}-'.format(previous_visit, predicted_visit))

    print('hidden_size{}-----learning_rate{}----l2_regularization{}----'.format(hidden_size, learning_rate, l2_regularization))

    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)

    logged = set()

    max_loss = 0.01
    max_pace = 0.0001

    mse_loss = 0
    count = 0
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_x_train.shape[0]
        with tf.GradientTape() as tape:
            predicted_trajectory = tf.zeros(shape=(batch, 0, feature_dims))
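            # Training is teacher-forced: at step j the encoder re-reads the
            # ground-truth visits 1..j and the decoder is conditioned on the
            # true previous visit y_j, not on its own earlier predictions.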
            for predicted_visit_ in range(predicted_visit):
                sequence_time_last_time = input_x_train[:, previous_visit+predicted_visit_-1, :]  # y_j
                for previous_visit_ in range(previous_visit+predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                    encode_c, encode_h = encode_share([sequence_time, encode_c, encode_h])
                context_state = encode_h  # h_j from 1 to j

                input_decode = tf.concat((sequence_time_last_time, context_state), axis=1)  # y_j and h_j
                if predicted_visit_ == 0:
                    decode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                    decode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))

                input_t = tf.reshape(input_t_train[:, previous_visit+predicted_visit_], [-1, 1])
                predicted_next_sequence, decode_c, decode_h = decoder_share([input_decode, input_t, decode_c, decode_h])
                predicted_next_sequence = tf.reshape(predicted_next_sequence, [batch, -1, feature_dims])
                predicted_trajectory = tf.concat((predicted_trajectory, predicted_next_sequence), axis=1)

            mse_loss = tf.reduce_mean(tf.keras.losses.mse(input_x_train[:, previous_visit: previous_visit+predicted_visit, :], predicted_trajectory))

            loss = mse_loss
            variables = list(encode_share.trainable_variables)
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            gradient = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradient, variables))

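            # Early stopping with patience: once per epoch, training halts
            # after the MSE has stayed below max_loss while improving by less
            # than max_pace for 10 consecutive logged epochs.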
            if train_set.epoch_completed not in logged:
                logged.add(train_set.epoch_completed)
                loss_pre = mse_loss
                mse_loss = tf.reduce_mean(tf.keras.losses.mse(input_x_train[:, previous_visit: previous_visit + predicted_visit, :], predicted_trajectory))
                loss_diff = loss_pre - mse_loss
                if mse_loss > max_loss:
                    count = 0

                else:
                    if loss_diff > max_pace:
                        count = 0
                    else:
                        count += 1
                if count > 9:
                    break

                input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
                input_t_test = tf.cast(test_set[:, :, 0], tf.float32)
                batch_test = input_x_test.shape[0]
                predicted_trajectory_test = tf.zeros(shape=[batch_test, 0, feature_dims])
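                # Evaluation rolls out autoregressively: the encoder re-reads
                # the observed previous_visit visits, then the visits generated
                # so far, so no future ground truth leaks into the rollout.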
                for predicted_visit_ in range(predicted_visit):
                    if predicted_visit_ == 0:
                        sequence_time_last_time_test = input_x_test[:, predicted_visit_+previous_visit-1, :]
                    for previous_visit_ in range(previous_visit):
                        sequence_time_test = input_x_test[:, previous_visit_, :]
                        if previous_visit_ == 0:
                            encode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                            encode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        encode_c_test, encode_h_test = encode_share([sequence_time_test, encode_c_test, encode_h_test])

                    if predicted_visit_ != 0:
                        for i in range(predicted_visit_):
                            sequence_input_t = predicted_trajectory_test[:, i, :]
                            encode_c_test, encode_h_test = encode_share([sequence_input_t, encode_c_test, encode_h_test])

                    context_state = encode_h_test

                    if predicted_visit_ == 0:
                        decode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        decode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                    input_decode_test = tf.concat((sequence_time_last_time_test, context_state), axis=1)
                    input_t = tf.reshape(input_t_test[:, previous_visit+predicted_visit_], [-1, 1])
                    predicted_next_sequence_test, decode_c_test, decode_h_test = decoder_share([input_decode_test, input_t, decode_c_test, decode_h_test])
                    sequence_time_last_time_test = predicted_next_sequence_test  # feed the generated sequence into next state
                    predicted_next_sequence_test = tf.reshape(predicted_next_sequence_test, [batch_test, -1, feature_dims])
                    predicted_trajectory_test = tf.concat((predicted_trajectory_test, predicted_next_sequence_test),
                                                          axis=1)
                mse_loss_predicted = tf.reduce_mean(tf.keras.losses.mse(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))
                mae_predicted = tf.reduce_mean(tf.keras.losses.mae(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))
                r_value_all = []
                for patient in range(batch_test):
                    r_value = 0.0
                    for feature in range(feature_dims):
                        x_ = input_x_test[patient, previous_visit:, feature].numpy().reshape(predicted_visit, 1)
                        y_ = predicted_trajectory_test[patient, :, feature].numpy().reshape(predicted_visit, 1)
                        r_value += DynamicTimeWarping(x_, y_)
                    r_value_all.append(r_value / feature_dims)  # mean DTW over features

                print('------epoch{}------mse_loss{}----predicted_mse-----{}---predicted_r_value---{}--'
                      '-count  {}'.format(train_set.epoch_completed,
                                          mse_loss, mse_loss_predicted,
                                          np.mean(r_value_all),
                                          count))
                # r_value_all = []
                # p_value_all = []
                # r_value_spearman = []
                # r_value_kendalltau = []
                # for visit in range(predicted_visit):
                #     for feature in range(feature_dims):
                #         x_ = input_x_test[:, previous_visit+visit, feature]
                #         y_ = predicted_trajectory_test[:, visit, feature]
                #         r_value_ = stats.pearsonr(x_, y_)
                #         r_value_spearman_ = stats.spearmanr(x_, y_)
                #         r_value_kendalltau_ = stats.kendalltau(x_, y_)
                #         if not np.isnan(r_value_[0]):
                #             r_value_all.append(np.abs(r_value_[0]))
                #             p_value_all.append(np.abs(r_value_[1]))
                #         if not np.isnan(r_value_spearman_[0]):
                #             r_value_spearman.append(np.abs(r_value_spearman_[0]))
                #         if not np.isnan(r_value_kendalltau_[0]):
                #             r_value_kendalltau.append(np.abs(r_value_kendalltau_[0]))
                # print('------epoch{}------mse_loss{}----predicted_mse-----{}---predicted_r_value---{}--'
                #       'r_value_spearman---{}---r_value_kendalltau---{}--count  {}'.format(train_set.epoch_completed,
                #                                                                           mse_loss, mse_loss_predicted,
                #                                                                           np.mean(r_value_all),
                #                                                                           np.mean(r_value_spearman),
                #                                                                           np.mean(r_value_kendalltau),
                #                                                                           count))
    tf.compat.v1.reset_default_graph()
    return mse_loss_predicted, mae_predicted, np.mean(r_value_all)
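
# DynamicTimeWarping is imported from elsewhere in the repo. For reference,
# a minimal sketch of the classic DTW distance between two 1-D series,
# assuming a plain absolute-difference point cost (the repo's version may
# use a different cost or normalization):
def dynamic_time_warping_sketch(x, y):
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    n, m = len(x), len(y)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(x[i - 1] - y[j - 1])
            # extend the cheapest of the three admissible warping moves
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1],
                                 cost[i - 1, j - 1])
    return cost[n, m]
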

# Example 2
def train(hidden_size, z_dims, l2_regularization, learning_rate, kl_imbalance,
          reconstruction_imbalance, generated_mse_imbalance):
    # train_set = np.load("../../Trajectory_generate/dataset_file/HF_train_.npy").reshape(-1, 6, 30)
    # test_set = np.load("../../Trajectory_generate/dataset_file/HF_test_.npy").reshape(-1, 6, 30)
    # test_set = np.load("../../Trajectory_generate/dataset_file/HF_validate_.npy").reshape(-1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)[:, :, 1:]
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)[:, :, 1:]
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)[:, :, 1:]

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    train_set = np.load(
        '../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy'
    ).reshape(-1, 13, 40)
    test_set = np.load(
        '../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy'
    ).reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 10

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50
    #
    # hidden_size = 2 ** (int(hidden_size))
    # z_dims = 2 ** (int(z_dims))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # kl_imbalance = 10 ** kl_imbalance
    # reconstruction_imbalance = 10 ** reconstruction_imbalance
    # generated_mse_imbalance = 10 ** generated_mse_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(
        previous_visit, predicted_visit))

    print(
        'hidden_size{}----z_dims{}------learning_rate{}----l2_regularization{}---'
        'kl_imbalance{}----reconstruction_imbalance '
        ' {}----generated_mse_imbalance{}----'.format(
            hidden_size, z_dims, learning_rate, l2_regularization,
            kl_imbalance, reconstruction_imbalance, generated_mse_imbalance))

    encode_share = Encoder(hidden_size=hidden_size)
    decode_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    prior_net = Prior(z_dims=z_dims)
    post_net = Post(z_dims=z_dims)

    logged = set()
    max_loss = 0.01
    max_pace = 0.0001
    loss = 0
    count = 0
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], dtype=tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_x_train.shape[0]

        with tf.GradientTape() as tape:
            generated_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            construct_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            z_log_var_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_mean_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_log_var_prior_all = tf.zeros(shape=[batch, 0, z_dims])
            z_mean_prior_all = tf.zeros(shape=[batch, 0, z_dims])

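            # Sequential CVAE step: the prior net proposes z_(i+1) from the
            # context h_i alone, while the posterior net also sees h_(i+1);
            # the decoder reconstructs visit i+1 from z_post and generates it
            # from z_prior, with the KL term pulling the two codes together.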
            for predicted_visit_ in range(predicted_visit):
                sequence_last_time = input_x_train[:, predicted_visit_ +
                                                   previous_visit - 1, :]
                sequence_time_current_time = input_x_train[:,
                                                           predicted_visit_ +
                                                           previous_visit, :]

                for previous_visit_ in range(previous_visit +
                                             predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                    encode_c, encode_h = encode_share(
                        [sequence_time, encode_c, encode_h])
                context_state = encode_h
                z_prior, z_mean_prior, z_log_var_prior = prior_net(
                    context_state)  # h_i--> z_(i+1)
                encode_c, encode_h = encode_share(
                    [sequence_time_current_time, encode_c,
                     encode_h])  # h_(i+1)
                z_post, z_mean_post, z_log_var_post = post_net(
                    [context_state, encode_h])  # h_i, h_(i+1) --> z_(i+1)
                if predicted_visit_ == 0:
                    decode_c_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                    decode_c_reconstruct = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_reconstruct = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                input_t = tf.reshape(
                    input_t_train[:, previous_visit + predicted_visit_],
                    [-1, 1])
                construct_next_visit, decode_c_reconstruct, decode_h_reconstruct = decode_share(
                    [
                        z_post, context_state, sequence_last_time,
                        decode_c_reconstruct, decode_h_reconstruct, input_t
                    ])
                construct_next_visit = tf.reshape(construct_next_visit,
                                                  [batch, -1, feature_dims])
                construct_trajectory = tf.concat(
                    (construct_trajectory, construct_next_visit), axis=1)

                generated_next_visit, decode_c_generate, decode_h_generate = decode_share(
                    [
                        z_prior, context_state, sequence_last_time,
                        decode_c_generate, decode_h_generate, input_t
                    ])
                generated_next_visit = tf.reshape(generated_next_visit,
                                                  (batch, -1, feature_dims))
                generated_trajectory = tf.concat(
                    (generated_trajectory, generated_next_visit), axis=1)

                z_mean_prior_all = tf.concat(
                    (z_mean_prior_all,
                     tf.reshape(z_mean_prior, [batch, -1, z_dims])),
                    axis=1)
                z_mean_post_all = tf.concat(
                    (z_mean_post_all,
                     tf.reshape(z_mean_post, [batch, -1, z_dims])),
                    axis=1)
                z_log_var_prior_all = tf.concat(
                    (z_log_var_prior_all,
                     tf.reshape(z_log_var_prior, [batch, -1, z_dims])),
                    axis=1)
                z_log_var_post_all = tf.concat(
                    (z_log_var_post_all,
                     tf.reshape(z_log_var_post, [batch, -1, z_dims])),
                    axis=1)

            mse_reconstruction = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], construct_trajectory))
            mse_generate = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            std_post = tf.math.sqrt(tf.exp(z_log_var_post_all))
            std_prior = tf.math.sqrt(tf.exp(z_log_var_prior_all))
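            # Elementwise closed-form KL between the two diagonal Gaussians:
            #   KL(N(mu_q, s_q^2) || N(mu_p, s_p^2))
            #     = log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 1/2
            # The 1e-9 floors below guard the log and the division.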
            kl_loss_element = 0.5 * (
                2 * tf.math.log(tf.maximum(std_prior, 1e-9)) -
                2 * tf.math.log(tf.maximum(std_post, 1e-9)) +
                (tf.math.pow(std_post, 2) + tf.math.pow(
                    (z_mean_post_all - z_mean_prior_all), 2)) /
                tf.maximum(tf.math.pow(std_prior, 2), 1e-9) - 1)
            kl_loss_all = tf.reduce_mean(kl_loss_element)

            loss = mse_reconstruction * reconstruction_imbalance + kl_loss_all * kl_imbalance + mse_generate * generated_mse_imbalance

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in post_net.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in prior_net.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            gradient = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradient, variables))

            if train_set.epoch_completed not in logged:
                logged.add(train_set.epoch_completed)
                loss_pre = mse_generate
                mse_reconstruction = tf.reduce_mean(
                    tf.keras.losses.mse(
                        input_x_train[:, previous_visit:previous_visit +
                                      predicted_visit, :],
                        construct_trajectory))
                mse_generate = tf.reduce_mean(
                    tf.keras.losses.mse(
                        input_x_train[:, previous_visit:previous_visit +
                                      predicted_visit, :],
                        generated_trajectory))
                kl_loss_all = tf.reduce_mean(
                    kl_loss(z_mean_post=z_mean_post_all,
                            z_mean_prior=z_mean_prior_all,
                            log_var_post=z_log_var_post_all,
                            log_var_prior=z_log_var_prior_all))
                loss = mse_reconstruction + mse_generate + kl_loss_all
                loss_diff = loss_pre - mse_generate

                if mse_generate > max_loss:
                    count = 0  # max_loss = 0.01

                else:
                    if loss_diff > max_pace:  # max_pace = 0.0001
                        count = 0
                    else:
                        count += 1

                if count > 9:
                    break

                input_test = test_set
                input_x_test = tf.cast(input_test[:, :, 1:], dtype=tf.float32)
                input_t_test = tf.cast(input_test[:, :, 0], tf.float32)
                batch_test = input_x_test.shape[0]
                generated_trajectory_test = tf.zeros(
                    shape=[batch_test, 0, feature_dims])
                for predicted_visit_ in range(predicted_visit):

                    for previous_visit_ in range(previous_visit):
                        sequence_time_test = input_x_test[:,
                                                          previous_visit_, :]
                        if previous_visit_ == 0:
                            encode_c_test = tf.Variable(
                                tf.zeros(shape=(batch_test, hidden_size)))
                            encode_h_test = tf.Variable(
                                tf.zeros(shape=(batch_test, hidden_size)))

                        encode_c_test, encode_h_test = encode_share(
                            [sequence_time_test, encode_c_test, encode_h_test])

                    if predicted_visit_ != 0:
                        for i in range(predicted_visit_):
                            sequence_input_t = generated_trajectory_test[:,
                                                                         i, :]
                            encode_c_test, encode_h_test = encode_share([
                                sequence_input_t, encode_c_test, encode_h_test
                            ])

                    context_state_test = encode_h_test
                    z_prior_test, z_mean_prior_test, z_log_var_prior_test = prior_net(
                        context_state_test)

                    if predicted_visit_ == 0:
                        decode_c_generate_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        decode_h_generate_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        sequence_last_time_test = input_x_test[:,
                                                               predicted_visit_
                                                               +
                                                               previous_visit -
                                                               1, :]

                    input_t = tf.reshape(
                        input_t_test[:, previous_visit + predicted_visit_],
                        [-1, 1])
                    sequence_last_time_test, decode_c_generate_test, decode_h_generate_test = decode_share(
                        [
                            z_prior_test, context_state_test,
                            sequence_last_time_test, decode_c_generate_test,
                            decode_h_generate_test, input_t
                        ])
                    generated_next_visit_test = sequence_last_time_test
                    generated_next_visit_test = tf.reshape(
                        generated_next_visit_test,
                        [batch_test, -1, feature_dims])
                    generated_trajectory_test = tf.concat(
                        (generated_trajectory_test, generated_next_visit_test),
                        axis=1)

                mse_generate_test = tf.reduce_mean(
                    tf.keras.losses.mse(
                        input_x_test[:, previous_visit:previous_visit +
                                     predicted_visit, :],
                        generated_trajectory_test))
                mae_generate_test = tf.reduce_mean(
                    tf.keras.losses.mae(
                        input_x_test[:, previous_visit:previous_visit +
                                     predicted_visit, :],
                        generated_trajectory_test))
                # r_value_all = []
                # p_value_all = []
                # r_value_spearman_all = []
                # r_value_kendalltau_all = []
                # for visit in range(predicted_visit):
                #     for feature in range(feature_dims):
                #         x_ = input_x_test[:, previous_visit+visit, feature]
                #         y_ = generated_trajectory_test[:, visit, feature]
                #         r_value_pearson = stats.pearsonr(x_, y_)
                #         r_value_spearman = stats.spearmanr(x_, y_)
                #         r_value_kendalltau = stats.kendalltau(x_, y_)
                #         if not np.isnan(r_value_pearson[0]):
                #             r_value_all.append(np.abs(r_value_pearson[0]))
                #             p_value_all.append(np.abs(r_value_pearson[1]))
                #
                #         if not np.isnan(r_value_spearman[0]):
                #             r_value_spearman_all.append(np.abs(r_value_spearman[0]))
                #
                #         if not np.isnan(r_value_kendalltau[0]):
                #             r_value_kendalltau_all.append(np.abs(r_value_kendalltau[0]))
                r_value_all = []
                for patient in range(batch_test):
                    r_value = 0.0
                    for feature in range(feature_dims):
                        x_ = input_x_test[patient, previous_visit:,
                                          feature].numpy().reshape(
                                              predicted_visit, 1)
                        y_ = generated_trajectory_test[
                            patient, :,
                            feature].numpy().reshape(predicted_visit, 1)
                        r_value += DynamicTimeWarping(x_, y_)
                    r_value_all.append(r_value / feature_dims)  # mean DTW over features
                print(
                    "epoch  {}---train_mse_generate {}--train_reconstruct {}--train_kl "
                    "{}--test_mse {}--test_mae  {}----r_value {}--"
                    "---count {}---".format(train_set.epoch_completed,
                                            mse_generate, mse_reconstruction,
                                            kl_loss_all, mse_generate_test,
                                            mae_generate_test,
                                            np.mean(r_value_all), count))

                # print("epoch  {}---train_mse_generate {}--train_reconstruct {}--train_kl "
                #       "{}--test_mse {}--test_mae  {}----r_value {}--r_value_spearman---{}---"
                #       "r_value_kendalltau---{}------count {}-".format(train_set.epoch_completed,
                #                                                       mse_generate,
                #                                                       mse_reconstruction,
                #                                                       kl_loss_all,
                #                                                       mse_generate_test,
                #                                                       mae_generate_test,
                #                                                       np.mean(r_value_all),
                #                                                       np.mean(r_value_spearman_all),
                #                                                       np.mean(r_value_kendalltau_all),
                #                                                       count))
    tf.compat.v1.reset_default_graph()
    return mse_generate_test, mae_generate_test, np.mean(r_value_all)
def test():

    train_set = np.load(
        "../../Trajectory_generate/dataset_file/train_x_.npy").reshape(
            -1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    test_set = np.load(
        "../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(
            -1, 6, 60)

    previous_visit = 3
    predicted_visit = 3

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0

    hidden_size = 64
    z_dims = 64

    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    prior_net = Prior(z_dims=z_dims)
    post_net = Post(z_dims=z_dims)
    hawkes_process = HawkesProcess()

    checkpoint_encode_share = tf.train.Checkpoint(encode_share=encode_share)
    checkpoint_encode_share.restore(
        tf.train.latest_checkpoint('./model/encoder_share/46.ckpt'))

    checkpoint_decode_share = tf.train.Checkpoint(decoder_share=decoder_share)
    checkpoint_decode_share.restore(
        tf.train.latest_checkpoint('./model/decoder_share/46.ckpt'))

    checkpoint_post = tf.train.Checkpoint(post_net=post_net)
    checkpoint_post.restore(
        tf.train.latest_checkpoint('./model/checkpoint_post/46.ckpt'))

    checkpoint_prior = tf.train.Checkpoint(prior_net=prior_net)
    checkpoint_prior.restore(
        tf.train.latest_checkpoint('./model/checkpoint_prior/46.ckpt'))

    checkpoint_hawkes = tf.train.Checkpoint(hawkes_process=hawkes_process)
    checkpoint_hawkes.restore(
        tf.train.latest_checkpoint('./model/checkpoint_hawkes/46.ckpt'))

    input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
    input_t_test = tf.cast(test_set[:, :, 0], tf.float32)
    batch_test = input_x_test.shape[0]
    generated_trajectory_test = tf.zeros(shape=[batch_test, 0, feature_dims])
    for predicted_visit_ in range(predicted_visit):
        for previous_visit_ in range(previous_visit + predicted_visit_):
            sequence_time_test = input_x_test[:, previous_visit_, :]
            if previous_visit_ == 0:
                encode_c_test = tf.Variable(
                    tf.zeros(shape=[batch_test, hidden_size]))
                encode_h_test = tf.Variable(
                    tf.zeros(shape=[batch_test, hidden_size]))
            encode_c_test, encode_h_test = encode_share(
                [sequence_time_test, encode_c_test, encode_h_test])

        if predicted_visit_ != 0:
            for i in range(predicted_visit_):
                sequence_input_t = generated_trajectory_test[:, i, :]
                encode_c_test, encode_h_test = encode_share(
                    [sequence_input_t, encode_c_test, encode_h_test])

        context_state_test = encode_h_test
        z_prior_test, z_mean_prior_test, z_log_var_prior = prior_net(
            context_state_test)

        if predicted_visit_ == 0:
            decode_c_generate_test = tf.Variable(
                tf.zeros(shape=[batch_test, hidden_size]))
            decode_h_generate_test = tf.Variable(
                tf.zeros(shape=[batch_test, hidden_size]))
            sequence_last_time_test = input_x_test[:, predicted_visit_ +
                                                   previous_visit - 1, :]

        current_time_index_shape_test = tf.ones(
            shape=[previous_visit + predicted_visit_])

        condition_intensity_test, likelihood_test = hawkes_process(
            [input_t_test, current_time_index_shape_test])

        sequence_next_visit_test, decode_c_generate_test, decode_h_generate_test = decoder_share(
            [
                z_prior_test, context_state_test, sequence_last_time_test,
                decode_c_generate_test,
                decode_h_generate_test * condition_intensity_test
            ])
        sequence_last_time_test = sequence_next_visit_test
        sequence_next_visit_test = tf.reshape(sequence_last_time_test,
                                              [batch_test, -1, feature_dims])

        generated_trajectory_test = tf.concat(
            (generated_trajectory_test, sequence_next_visit_test), axis=1)

    mse_generated_test = tf.reduce_mean(
        tf.keras.losses.mse(
            input_x_test[:,
                         previous_visit:previous_visit + predicted_visit, :],
            generated_trajectory_test))
    mae_generated_test = tf.reduce_mean(
        tf.keras.losses.mae(
            input_x_test[:,
                         previous_visit:previous_visit + predicted_visit, :],
            generated_trajectory_test))

    r_value_all = []
    p_value_all = []

    for r in range(predicted_visit):
        x_ = tf.reshape(input_x_test[:, previous_visit + r, :], (-1, ))
        y_ = tf.reshape(generated_trajectory_test[:, r, :], (-1, ))
        r_value_ = stats.pearsonr(x_, y_)
        r_value_all.append(r_value_[0])
        p_value_all.append(r_value_[1])

    return mse_generated_test, mae_generated_test, np.mean(r_value_all)
def train(hidden_size, z_dims, l2_regularization, learning_rate, kl_imbalance,
          reconstruction_imbalance, generated_mse_imbalance,
          likelihood_imbalance):
    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)

    train_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_train_.npy').reshape(
            -1, 6, 30)
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_validate_.npy').reshape(-1, 6, 30)
    test_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_test_.npy').reshape(
            -1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    previous_visit = 1
    predicted_visit = 5

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50
    #
    # hidden_size = 2 ** (int(hidden_size))
    # z_dims = 2 ** (int(z_dims))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # kl_imbalance = 10 ** kl_imbalance
    # reconstruction_imbalance = 10 ** reconstruction_imbalance
    # generated_mse_imbalance = 10 ** generated_mse_imbalance
    # likelihood_imbalance = 10 ** likelihood_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(
        previous_visit, predicted_visit))
    print(
        'hidden_size{}----z_dims{}------learning_rate{}----l2_regularization{}---'
        'kl_imbalance{}----reconstruction_imbalance '
        ' {}----generated_mse_imbalance{}----likelihood_imbalance{}----'.format(
            hidden_size, z_dims, learning_rate, l2_regularization,
            kl_imbalance, reconstruction_imbalance, generated_mse_imbalance,
            likelihood_imbalance))

    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    prior_net = Prior(z_dims=z_dims)
    post_net = Post(z_dims=z_dims)
    hawkes_process = HawkesProcess()

    logged = set()
    max_loss = 0.001
    max_pace = 0.0001

    loss = 0
    count = 0
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_x_train.shape[0]

        with tf.GradientTape() as tape:
            generated_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            reconstruction_trajectory = tf.zeros(
                shape=[batch, 0, feature_dims])
            z_log_var_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_mean_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_mean_prior_all = tf.zeros(shape=[batch, 0, z_dims])
            z_log_var_prior_all = tf.zeros(shape=[batch, 0, z_dims])
            probability_likelihood = tf.zeros(shape=[batch, 0, 1])

            for predicted_visit_ in range(predicted_visit):
                sequence_time_current_time = input_x_train[:,
                                                           predicted_visit_ +
                                                           previous_visit, :]
                sequence_time_last_time = input_x_train[:, predicted_visit_ +
                                                        previous_visit - 1, :]
                for previous_visit_ in range(previous_visit +
                                             predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share(
                        [sequence_time, encode_c, encode_h])
                context_state = encode_h

                if predicted_visit_ == 0:
                    decode_c_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                    decode_c_reconstruct = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_reconstruct = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                z_prior, z_mean_prior, z_log_var_prior = prior_net(
                    context_state)
                encode_c, encode_h = encode_share(
                    [sequence_time_current_time, encode_c, encode_h])
                z_post, z_mean_post, z_log_var_post = post_net(
                    [context_state, encode_h])

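                # hawkes_process is assumed to return the conditional intensity
                # of a Hawkes point process with exponential kernel at the
                # current visit time,
                #   lambda(t) = mu + sum_{t_i < t} alpha * exp(-beta * (t - t_i)),
                # together with the likelihood of the observed visit times;
                # the intensity modulates the decoder hidden state below.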
                current_time_index_shape = tf.ones(
                    shape=[predicted_visit_ + previous_visit])
                condition_intensity, likelihood = hawkes_process(
                    [input_t_train, current_time_index_shape])
                generated_next_visit, decode_c_generate, decode_h_generate = decoder_share(
                    [
                        z_prior, context_state, sequence_time_last_time,
                        decode_c_generate,
                        decode_h_generate * condition_intensity
                    ])
                likelihood = tf.reshape(likelihood, [batch, -1, 1])
                probability_likelihood = tf.concat(
                    (probability_likelihood, likelihood), axis=1)
                generated_next_visit = tf.reshape(generated_next_visit,
                                                  [batch, -1, feature_dims])
                generated_trajectory = tf.concat(
                    (generated_trajectory, generated_next_visit), axis=1)

                construct_next_visit, decode_c_reconstruct, decode_h_reconstruct = decoder_share(
                    [
                        z_post, context_state, sequence_time_last_time,
                        decode_c_reconstruct,
                        decode_h_reconstruct * condition_intensity
                    ])
                construct_next_visit = tf.reshape(construct_next_visit,
                                                  [batch, -1, feature_dims])
                reconstruction_trajectory = tf.concat(
                    (reconstruction_trajectory, construct_next_visit), axis=1)

                z_mean_prior_all = tf.concat(
                    (z_mean_prior_all,
                     tf.reshape(z_mean_prior, [batch, -1, z_dims])),
                    axis=1)
                z_log_var_prior_all = tf.concat(
                    (z_log_var_prior_all,
                     tf.reshape(z_log_var_prior, [batch, -1, z_dims])),
                    axis=1)

                z_mean_post_all = tf.concat(
                    (z_mean_post_all,
                     tf.reshape(z_mean_post, [batch, -1, z_dims])),
                    axis=1)
                z_log_var_post_all = tf.concat(
                    (z_log_var_post_all,
                     tf.reshape(z_log_var_post, [batch, -1, z_dims])),
                    axis=1)

            mse_reconstruction = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :],
                    reconstruction_trajectory))
            mse_generated = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :],
                    generated_trajectory))

            std_post = tf.math.sqrt(tf.exp(z_log_var_post_all))
            std_prior = tf.math.sqrt(tf.exp(z_log_var_prior_all))

            kl_loss_element = 0.5 * (
                2 * tf.math.log(tf.maximum(std_prior, 1e-9)) -
                2 * tf.math.log(tf.maximum(std_post, 1e-9)) +
                (tf.math.pow(std_post, 2) + tf.math.pow(
                    (z_mean_post_all - z_mean_prior_all), 2)) /
                tf.maximum(tf.math.pow(std_prior, 2), 1e-9) - 1)
            kl_loss_all = tf.reduce_mean(kl_loss_element)
            # print('kl_loss---{}'.format(kl_loss_all))

            likelihood_loss = tf.reduce_mean(probability_likelihood)

            loss = mse_reconstruction * reconstruction_imbalance + mse_generated * generated_mse_imbalance + kl_loss_all * kl_imbalance + likelihood_loss * likelihood_imbalance

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in prior_net.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in post_net.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in hawkes_process.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)


            gradient = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradient, variables))

            if train_set.epoch_completed not in logged:
                logged.add(train_set.epoch_completed)
                loss_pre = mse_generated

                mse_generated = tf.reduce_mean(
                    tf.keras.losses.mse(
                        input_x_train[:, previous_visit:previous_visit +
                                      predicted_visit, :],
                        generated_trajectory))
                loss_diff = loss_pre - mse_generated

                if mse_generated > max_loss:
                    count = 0

                else:
                    if loss_diff > max_pace:
                        count = 0

                    else:
                        count += 1
                if count > 9:
                    break

                input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
                input_t_test = tf.cast(test_set[:, :, 0], tf.float32)
                batch_test = input_x_test.shape[0]
                generated_trajectory_test = tf.zeros(
                    shape=[batch_test, 0, feature_dims])
                for predicted_visit_ in range(predicted_visit):
                    for previous_visit_ in range(previous_visit):
                        sequence_time_test = input_x_test[:,
                                                          previous_visit_, :]
                        if previous_visit_ == 0:
                            encode_c_test = tf.Variable(
                                tf.zeros(shape=[batch_test, hidden_size]))
                            encode_h_test = tf.Variable(
                                tf.zeros(shape=[batch_test, hidden_size]))
                        encode_c_test, encode_h_test = encode_share(
                            [sequence_time_test, encode_c_test, encode_h_test])

                    if predicted_visit_ != 0:
                        for i in range(predicted_visit_):
                            sequence_input_t = generated_trajectory_test[:,
                                                                         i, :]
                            encode_c_test, encode_h_test = encode_share([
                                sequence_input_t, encode_c_test, encode_h_test
                            ])

                    context_state_test = encode_h_test
                    z_prior_test, z_mean_prior_test, z_log_var_prior = prior_net(
                        context_state_test)

                    if predicted_visit_ == 0:
                        decode_c_generate_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        decode_h_generate_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        sequence_last_time_test = input_x_test[:,
                                                               predicted_visit_
                                                               +
                                                               previous_visit -
                                                               1, :]

                    current_time_index_shape_test = tf.ones(
                        shape=[previous_visit + predicted_visit_])

                    condition_intensity_test, likelihood_test = hawkes_process(
                        [input_t_test, current_time_index_shape_test])

                    sequence_next_visit_test, decode_c_generate_test, decode_h_generate_test = decoder_share(
                        [
                            z_prior_test, context_state_test,
                            sequence_last_time_test, decode_c_generate_test,
                            decode_h_generate_test * condition_intensity_test
                        ])
                    sequence_last_time_test = sequence_next_visit_test
                    sequence_next_visit_test = tf.reshape(
                        sequence_last_time_test,
                        [batch_test, -1, feature_dims])

                    generated_trajectory_test = tf.concat(
                        (generated_trajectory_test, sequence_next_visit_test),
                        axis=1)

                mse_generated_test = tf.reduce_mean(
                    tf.keras.losses.mse(
                        input_x_test[:, previous_visit:previous_visit +
                                     predicted_visit, :],
                        generated_trajectory_test))
                mae_generated_test = tf.reduce_mean(
                    tf.keras.losses.mae(
                        input_x_test[:, previous_visit:previous_visit +
                                     predicted_visit, :],
                        generated_trajectory_test))

                r_value_all = []
                p_value_all = []

                for r in range(predicted_visit):
                    x_ = tf.reshape(input_x_test[:, previous_visit + r, :],
                                    (-1, ))
                    y_ = tf.reshape(generated_trajectory_test[:, r, :], (-1, ))
                    r_value_ = stats.pearsonr(x_, y_)
                    r_value_all.append(r_value_[0])
                    p_value_all.append(r_value_[1])

                print(
                    "epoch  {}---train_mse_generate {}--train_reconstruct {}--train_kl "
                    "{}--test_mse {}--test_mae  {}----r_value {}---count {}".
                    format(train_set.epoch_completed, mse_generated,
                           mse_reconstruction, kl_loss_all, mse_generated_test,
                           mae_generated_test, np.mean(r_value_all), count))

    tf.compat.v1.reset_default_graph()
    # return mse_generated_test, mae_generated_test, np.mean(r_value_all)
    return -1 * mse_generated_test

# Example 5
def train(hidden_size, z_dims, l2_regularization, learning_rate, n_disc,
          generated_mse_imbalance, generated_loss_imbalance, kl_imbalance,
          reconstruction_mse_imbalance, likelihood_imbalance):
    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)

    # train_set = np.load('../../Trajectory_generate/dataset_file/HF_train_.npy').reshape(-1, 6, 30)[:, :, :]
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_validate_.npy').reshape(-1, 6, 30)[:, :, :]
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_test_.npy').reshape(-1, 6, 30)[:, :, :]

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    train_set = np.load(
        '../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy'
    ).reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)[:1072, :, :]
    test_set = np.load(
        '../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy'
    ).reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 10

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50

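    # The raw arguments presumably come from a hyper-parameter search over
    # exponents and are mapped back to usable values here: powers of 2 for
    # the layer widths, powers of 10 for the rates and loss weights.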
    hidden_size = 2**(int(hidden_size))
    z_dims = 2**(int(z_dims))
    l2_regularization = 10**l2_regularization
    learning_rate = 10**learning_rate
    n_disc = int(n_disc)
    generated_mse_imbalance = 10**generated_mse_imbalance
    generated_loss_imbalance = 10**generated_loss_imbalance
    kl_imbalance = 10**kl_imbalance
    reconstruction_mse_imbalance = 10**reconstruction_mse_imbalance
    likelihood_imbalance = 10**likelihood_imbalance

    print('feature_dims---{}'.format(feature_dims))

    print('previous_visit---{}---predicted_visit----{}-'.format(
        previous_visit, predicted_visit))

    print(
        'hidden_size---{}---z_dims---{}---l2_regularization---{}---learning_rate---{}--n_disc---{}-'
        'generated_mse_imbalance---{}---generated_loss_imbalance---{}---'
        'kl_imbalance---{}---reconstruction_mse_imbalance---{}---'
        'likelihood_imbalance---{}'.format(
            hidden_size, z_dims, l2_regularization, learning_rate, n_disc,
            generated_mse_imbalance, generated_loss_imbalance, kl_imbalance,
            reconstruction_mse_imbalance, likelihood_imbalance))

    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    discriminator = Discriminator(predicted_visit=predicted_visit,
                                  hidden_size=hidden_size,
                                  previous_visit=previous_visit)

    post_net = Post(z_dims=z_dims)
    prior_net = Prior(z_dims=z_dims)

    hawkes_process = HawkesProcess()
    loss = 0
    count = 0
    optimizer_generation = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    optimizer_discriminator = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    logged = set()
    max_loss = 0.001
    max_pace = 0.0001

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_train.shape[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape(
                persistent=True) as disc_tape:
            generated_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            probability_likelihood = tf.zeros(shape=[batch, 0, 1])
            reconstructed_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            z_mean_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_log_var_post_all = tf.zeros(shape=[batch, 0, z_dims])
            z_mean_prior_all = tf.zeros(shape=[batch, 0, z_dims])
            z_log_var_prior_all = tf.zeros(shape=[batch, 0, z_dims])
            for predicted_visit_ in range(predicted_visit):
                sequence_last_time = input_x_train[:, previous_visit +
                                                   predicted_visit_ - 1, :]
                sequence_current_time = input_x_train[:, previous_visit +
                                                      predicted_visit_, :]
                for previous_visit_ in range(previous_visit +
                                             predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share(
                        [sequence_time, encode_c, encode_h])
                context_state = encode_h  # h_i
                encode_c, encode_h = encode_share(
                    [sequence_current_time, encode_c, encode_h])  # h_(i+1)

                if predicted_visit_ == 0:
                    decode_c_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_generate = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                    decode_c_reconstruction = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h_reconstruction = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                z_post, z_mean_post, z_log_var_post = post_net(
                    [context_state, encode_h])
                z_prior, z_mean_prior, z_log_var_prior = prior_net(
                    context_state)

                current_time_index_shape = tf.ones(
                    shape=[previous_visit + predicted_visit_])
                condition_value, likelihood = hawkes_process(
                    [input_t_train, current_time_index_shape])
                probability_likelihood = tf.concat(
                    (probability_likelihood,
                     tf.reshape(likelihood, [batch, -1, 1])),
                    axis=1)
                probability_likelihood = tf.keras.activations.softmax(
                    probability_likelihood)
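                # The Hawkes conditional intensity rescales the decoder hidden
                # state below, injecting visit-timing information into both the
                # generation and reconstruction branches.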
                # generation
                generated_next_visit, decode_c_generate, decode_h_generate = decoder_share(
                    [
                        z_prior, context_state, sequence_last_time,
                        decode_c_generate, decode_h_generate * condition_value
                    ])
                # reconstruction
                reconstructed_next_visit, decode_c_reconstruction, decode_h_reconstruction = decoder_share(
                    [
                        z_post, context_state, sequence_last_time,
                        decode_c_reconstruction,
                        decode_h_reconstruction * condition_value
                    ])

                reconstructed_trajectory = tf.concat(
                    (reconstructed_trajectory,
                     tf.reshape(reconstructed_next_visit,
                                [batch, -1, feature_dims])),
                    axis=1)
                generated_trajectory = tf.concat(
                    (generated_trajectory,
                     tf.reshape(generated_next_visit,
                                [batch, -1, feature_dims])),
                    axis=1)

                z_mean_post_all = tf.concat(
                    (z_mean_post_all,
                     tf.reshape(z_mean_post, [batch, -1, z_dims])),
                    axis=1)
                z_mean_prior_all = tf.concat(
                    (z_mean_prior_all,
                     tf.reshape(z_mean_prior, [batch, -1, z_dims])),
                    axis=1)

                z_log_var_post_all = tf.concat(
                    (z_log_var_post_all,
                     tf.reshape(z_log_var_post, [batch, -1, z_dims])),
                    axis=1)
                z_log_var_prior_all = tf.concat(
                    (z_log_var_prior_all,
                     tf.reshape(z_log_var_prior, [batch, -1, z_dims])),
                    axis=1)

            d_real_pre_, d_fake_pre_ = discriminator(input_x_train,
                                                     generated_trajectory)
            d_real_pre_loss = cross_entropy(tf.ones_like(d_real_pre_),
                                            d_real_pre_)
            d_fake_pre_loss = cross_entropy(tf.zeros_like(d_fake_pre_),
                                            d_fake_pre_)
            d_loss = d_real_pre_loss + d_fake_pre_loss

            gen_loss = cross_entropy(tf.ones_like(d_fake_pre_), d_fake_pre_)
            generated_mse_loss = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))
            reconstructed_mse_loss = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :],
                    reconstructed_trajectory))

            std_post = tf.math.sqrt(tf.exp(z_log_var_post_all))
            std_prior = tf.math.sqrt(tf.exp(z_log_var_prior_all))

            # Closed-form KL(posterior || prior) for diagonal Gaussians,
            # 0.5 * (2*log(s_prior) - 2*log(s_post)
            #        + (s_post^2 + (mu_post - mu_prior)^2) / s_prior^2 - 1);
            # the tf.maximum(., 1e-9) guards avoid log(0) and division by zero.
            kl_loss_element = 0.5 * (
                2 * tf.math.log(tf.maximum(std_prior, 1e-9)) -
                2 * tf.math.log(tf.maximum(std_post, 1e-9)) +
                (tf.square(std_post) +
                 tf.square(z_mean_post_all - z_mean_prior_all)) /
                tf.maximum(tf.square(std_prior), 1e-9) - 1)
            kl_loss = tf.reduce_mean(kl_loss_element)

            likelihood_loss = tf.reduce_mean(probability_likelihood)

            # Assign (rather than "+=") so each batch is optimized on its own
            # loss instead of a running total carried across iterations.
            loss = (generated_mse_loss * generated_mse_imbalance +
                    reconstructed_mse_loss * reconstruction_mse_imbalance +
                    kl_loss * kl_imbalance +
                    likelihood_loss * likelihood_imbalance +
                    gen_loss * generated_loss_imbalance)

            for weight in discriminator.trainable_variables:
                d_loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in post_net.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in prior_net.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in hawkes_process.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

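        # Give the discriminator n_disc updates per generator step; since the
        # persistent tape replays the same recorded forward pass, the same
        # gradient is recomputed and applied n_disc times.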
        for disc in range(n_disc):
            gradient_disc = disc_tape.gradient(
                d_loss, discriminator.trainable_variables)
            optimizer_discriminator.apply_gradients(
                zip(gradient_disc, discriminator.trainable_variables))

        gradient_gen = gen_tape.gradient(loss, variables)
        optimizer_generation.apply_gradients(zip(gradient_gen, variables))

        if train_set.epoch_completed % 1 == 0 and train_set.epoch_completed not in logged:

            logged.add(train_set.epoch_completed)
            loss_pre = generated_mse_loss
            mse_generated = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            loss_diff = loss_pre - mse_generated

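            # Early stopping: halt once the training MSE has stayed below
            # max_loss without improving by more than max_pace for 10
            # consecutive logged epochs.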
            if mse_generated > max_loss:
                count = 0
            else:
                if loss_diff > max_pace:
                    count = 0
                else:
                    count += 1
            if count > 9:
                break

            input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
            input_t_test = tf.cast(test_set[:, :, 0], tf.float32)

            batch_test = test_set.shape[0]
            generated_trajectory_test = tf.zeros(
                shape=[batch_test, 0, feature_dims])
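            # Test-time generation is autoregressive: the encoder first reads
            # the previous_visit ground-truth visits, then every visit
            # generated so far, before decoding the next one.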
            for predicted_visit_ in range(predicted_visit):
                for previous_visit_ in range(previous_visit):
                    sequence_time_test = input_x_test[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        encode_h_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))

                    encode_c_test, encode_h_test = encode_share(
                        [sequence_time_test, encode_c_test, encode_h_test])

                if predicted_visit_ != 0:
                    for i in range(predicted_visit_):
                        encode_c_test, encode_h_test = encode_share([
                            generated_trajectory_test[:, i, :], encode_c_test,
                            encode_h_test
                        ])

                context_state_test = encode_h_test

                if predicted_visit_ == 0:
                    decode_c_generate_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    decode_h_generate_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    sequence_last_time_test = input_x_test[:, previous_visit +
                                                           predicted_visit_ -
                                                           1, :]

                z_prior_test, z_mean_prior_test, z_log_var_prior_test = prior_net(
                    context_state_test)
                current_time_index_shape_test = tf.ones(
                    [previous_visit + predicted_visit_])
                intensity_value_test, likelihood_test = hawkes_process(
                    [input_t_test, current_time_index_shape_test])

                generated_next_visit_test, decode_c_generate_test, decode_h_generate_test = decoder_share(
                    [
                        z_prior_test, context_state_test,
                        sequence_last_time_test, decode_c_generate_test,
                        decode_h_generate_test * intensity_value_test
                    ])
                generated_trajectory_test = tf.concat(
                    (generated_trajectory_test,
                     tf.reshape(generated_next_visit_test,
                                [batch_test, -1, feature_dims])),
                    axis=1)
                sequence_last_time_test = generated_next_visit_test

            mse_generated_test = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))
            mae_generated_test = tf.reduce_mean(
                tf.keras.losses.mae(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))

            r_value_all = []
            for patient in range(batch_test):
                r_value = 0.0
                for feature in range(feature_dims):
                    x_ = input_x_test[patient, previous_visit:previous_visit +
                                      predicted_visit,
                                      feature].numpy().reshape(
                                          predicted_visit, 1)
                    y_ = generated_trajectory_test[patient, :,
                                                   feature].numpy().reshape(
                                                       predicted_visit, 1)
                    r_value += DynamicTimeWarping(x_, y_)
                r_value_all.append(r_value / feature_dims)  # mean DTW over features

            print(
                'epoch---{}---train_mse_generated---{}---likelihood_loss---{}---'
                'train_mse_reconstruct---{}---train_kl---{}---'
                'test_mse---{}---test_mae---{}---'
                'r_value_test---{}---count---{}'.format(
                    train_set.epoch_completed, generated_mse_loss,
                    likelihood_loss, reconstructed_mse_loss,
                    kl_loss, mse_generated_test, mae_generated_test,
                    np.mean(r_value_all), count))
    tf.compat.v1.reset_default_graph()
    # return mse_generated_test, mae_generated_test, np.mean(r_value_all)
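    # Negated so that a maximizing hyperparameter search (e.g. Bayesian
    # optimization) effectively minimizes the test MSE.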
    return -1 * mse_generated_test
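
# DynamicTimeWarping is used for the r_value metric above but is not defined
# in this excerpt. A minimal sketch, assuming it returns the classic DTW
# alignment cost between two (T, 1) arrays under a squared-Euclidean local
# distance (np is numpy, imported at the top of this file); the project's
# actual helper may differ.
def DynamicTimeWarping(x, y):
    n, m = x.shape[0], y.shape[0]
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            local = np.sum((x[i - 1] - y[j - 1]) ** 2)  # pointwise cost
            # extend the cheapest of the three predecessor alignments
            cost[i, j] = local + min(cost[i - 1, j],
                                     cost[i, j - 1],
                                     cost[i - 1, j - 1])
    return cost[n, m]
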
# Example #6
def train(hidden_size, learning_rate, l2_regularization, n_disc,
          generated_mse_imbalance, generated_loss_imbalance,
          likelihood_imbalance):
    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)

    train_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_train_.npy').reshape(
            -1, 6, 30)
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_validate_.npy').reshape(-1, 6, 30)
    test_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_test_.npy').reshape(
            -1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    # train_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 3

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50

    # hidden_size = 2 ** (int(hidden_size))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # n_disc = int(n_disc)
    # generated_mse_imbalance = 10 ** generated_mse_imbalance
    # generated_loss_imbalance = 10 ** generated_loss_imbalance
    # likelihood_imbalance = 10 ** likelihood_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(
        previous_visit, predicted_visit))

    print(
        'hidden_size---{}---learning_rate---{}---l2_regularization---{}---n_disc---{}---'
        'generated_mse_imbalance---{}---generated_loss_imbalance---{}---'
        'likelihood_imbalance---{}'.format(hidden_size, learning_rate,
                                           l2_regularization, n_disc,
                                           generated_mse_imbalance,
                                           generated_loss_imbalance,
                                           likelihood_imbalance))
    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    hawkes_process = HawkesProcess()
    discriminator = Discriminator(previous_visit=previous_visit,
                                  predicted_visit=predicted_visit,
                                  hidden_size=hidden_size)

    logged = set()
    max_loss = 0.001
    max_pace = 0.0001
    count = 0
    loss = 0
    optimizer_generation = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    optimizer_discriminator = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_train.shape[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape(
                persistent=True) as disc_tape:
            generated_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            probability_likelihood = tf.zeros(shape=[batch, 0, 1])
            for predicted_visit_ in range(predicted_visit):
                sequence_last_time = input_x_train[:, previous_visit +
                                                   predicted_visit_ - 1, :]
                for previous_visit_ in range(previous_visit +
                                             predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share(
                        [sequence_time, encode_c, encode_h])
                context_state = encode_h

                if predicted_visit_ == 0:
                    decode_c = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                current_time_index_shape = tf.ones(
                    shape=[previous_visit + predicted_visit_])
                intensity_value, likelihood = hawkes_process(
                    [input_t_train, current_time_index_shape])
                probability_likelihood = tf.concat(
                    (probability_likelihood,
                     tf.reshape(likelihood, [batch, -1, 1])),
                    axis=1)

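                # Scale the decoder hidden state by the Hawkes intensity so
                # that irregular visit timing modulates the recurrent memory.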
                generated_next_visit, decode_c, decode_h = decoder_share([
                    sequence_last_time, context_state, decode_c,
                    decode_h * intensity_value
                ])
                generated_trajectory = tf.concat(
                    (generated_trajectory,
                     tf.reshape(generated_next_visit,
                                [batch, -1, feature_dims])),
                    axis=1)

            d_real_pre_, d_fake_pre_ = discriminator(input_x_train,
                                                     generated_trajectory)
            d_real_pre_loss = cross_entropy(tf.ones_like(d_real_pre_),
                                            d_real_pre_)
            d_fake_pre_loss = cross_entropy(tf.zeros_like(d_fake_pre_),
                                            d_fake_pre_)
            d_loss = d_real_pre_loss + d_fake_pre_loss

            gen_loss = cross_entropy(tf.ones_like(d_fake_pre_), d_fake_pre_)
            generated_mse_loss = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            likelihood_loss = tf.reduce_mean(probability_likelihood)

            # Per-batch loss; assignment (rather than "+=") avoids carrying a
            # running total across iterations.
            loss = (generated_mse_loss * generated_mse_imbalance +
                    likelihood_loss * likelihood_imbalance +
                    gen_loss * generated_loss_imbalance)

            for weight in discriminator.trainable_variables:
                d_loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in hawkes_process.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

        for disc in range(n_disc):
            gradient_disc = disc_tape.gradient(
                d_loss, discriminator.trainable_variables)
            optimizer_discriminator.apply_gradients(
                zip(gradient_disc, discriminator.trainable_variables))

        gradient_gen = gen_tape.gradient(loss, variables)
        optimizer_generation.apply_gradients(zip(gradient_gen, variables))

        if train_set.epoch_completed % 1 == 0 and train_set.epoch_completed not in logged:
            logged.add(train_set.epoch_completed)
            loss_pre = generated_mse_loss

            mse_generated = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            loss_diff = loss_pre - mse_generated

            if mse_generated > max_loss:
                count = 0
            else:
                if loss_diff > max_pace:
                    count = 0
                else:
                    count += 1
            if count > 9:
                break

            input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
            input_t_test = tf.cast(test_set[:, :, 0], tf.float32)

            batch_test = test_set.shape[0]
            generated_trajectory_test = tf.zeros(
                shape=[batch_test, 0, feature_dims])
            for predicted_visit_ in range(predicted_visit):
                for previous_visit_ in range(previous_visit):
                    sequence_time_test = input_x_test[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        encode_h_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))

                    encode_c_test, encode_h_test = encode_share(
                        [sequence_time_test, encode_c_test, encode_h_test])

                if predicted_visit_ != 0:
                    for i in range(predicted_visit_):
                        encode_c_test, encode_h_test = encode_share([
                            generated_trajectory_test[:, i, :], encode_c_test,
                            encode_h_test
                        ])

                context_state_test = encode_h_test

                if predicted_visit_ == 0:
                    decode_c_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    decode_h_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    sequence_last_time_test = input_x_test[:, previous_visit +
                                                           predicted_visit_ -
                                                           1, :]

                current_time_index_shape = tf.ones(
                    [previous_visit + predicted_visit_])
                intensity_value, likelihood = hawkes_process(
                    [input_t_test, current_time_index_shape])
                generated_next_visit, decode_c_test, decode_h_test = decoder_share(
                    [
                        sequence_last_time_test, context_state_test,
                        decode_c_test, decode_h_test * intensity_value
                    ])
                generated_trajectory_test = tf.concat(
                    (generated_trajectory_test,
                     tf.reshape(generated_next_visit,
                                [batch_test, -1, feature_dims])),
                    axis=1)
                sequence_last_time_test = generated_next_visit

            mse_generated_test = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))
            mae_generated_test = tf.reduce_mean(
                tf.keras.losses.mae(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))

            r_value_all = []
            for patient in range(batch_test):
                r_value = 0.0
                for feature in range(feature_dims):
                    x_ = input_x_test[patient, previous_visit:,
                                      feature].numpy().reshape(
                                          predicted_visit, 1)
                    y_ = generated_trajectory_test[patient, :,
                                                   feature].numpy().reshape(
                                                       predicted_visit, 1)
                    r_value += DynamicTimeWarping(x_, y_)
                r_value_all.append(r_value / feature_dims)  # mean DTW over features

            print(
                '------epoch---{}------mse_loss---{}----mae_loss---{}------'
                'predicted_r_value---{}---count---{}'.format(
                    train_set.epoch_completed, mse_generated_test,
                    mae_generated_test, np.mean(r_value_all), count))

            # r_value_all = []
            # p_value_all = []
            # r_value_spearman = []
            # r_value_kendalltau = []
            # for visit in range(predicted_visit):
            #     for feature in range(feature_dims):
            #         x_ = input_x_test[:, previous_visit+visit, feature]
            #         y_ = generated_trajectory_test[:, visit, feature]
            #         r_value_ = stats.pearsonr(x_, y_)
            #         r_value_spearman_ = stats.spearmanr(x_, y_)
            #         r_value_kendalltau_ = stats.kendalltau(x_, y_)
            #         if not np.isnan(r_value_[0]):
            #             r_value_all.append(np.abs(r_value_[0]))
            #             p_value_all.append(np.abs(r_value_[1]))
            #         if not np.isnan(r_value_spearman_[0]):
            #             r_value_spearman.append(np.abs(r_value_spearman_[0]))
            #         if not np.isnan(r_value_kendalltau_[0]):
            #             r_value_kendalltau.append(np.abs(r_value_kendalltau_[0]))
            # print('------epoch{}------mse_loss{}----mae_loss{}------predicted_r_value---{}--'
            #       'r_value_spearman---{}---r_value_kendalltau---{}--count  {}'.format(train_set.epoch_completed,
            #                                                                           mse_generated_test,

            #                                                                           mae_generated_test,
            #                                                                           np.mean(r_value_all),
            #                                                                           np.mean(r_value_spearman),
            #                                                                           np.mean(r_value_kendalltau),
            #                                                                           count))

    tf.compat.v1.reset_default_graph()
    return mse_generated_test, mae_generated_test, np.mean(r_value_all)
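

# The variant below drops the discriminator and the posterior/prior networks,
# keeping only the Hawkes-modulated encoder-decoder trained with an MSE plus
# Hawkes-likelihood objective.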
def train(hidden_size, l2_regularization, learning_rate, generated_imbalance, likelihood_imbalance):
    train_set = np.load("../../Trajectory_generate/dataset_file/HF_train_.npy").reshape(-1, 6, 30)
    test_set = np.load("../../Trajectory_generate/dataset_file/HF_test_.npy").reshape(-1, 6, 30)
    # test_set = np.load("../../Trajectory_generate/dataset_file/HF_validate_.npy").reshape(-1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    # train_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 3

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50

    # hidden_size = 2 ** (int(hidden_size))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # generated_imbalance = 10 ** generated_imbalance
    # likelihood_imbalance = 10 ** likelihood_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(previous_visit, predicted_visit))

    print('hidden_size----{}---'
          'l2_regularization---{}---'
          'learning_rate---{}---'
          'generated_imbalance---{}---'
          'likelihood_imbalance---{}'.
          format(hidden_size, l2_regularization, learning_rate,
                 generated_imbalance, likelihood_imbalance))

    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    encode_share = Encoder(hidden_size=hidden_size)
    hawkes_process = HawkesProcess()

    logged = set()
    max_loss = 0.01
    max_pace = 0.001
    loss = 0

    count = 0
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        batch = input_train.shape[0]
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)

        with tf.GradientTape() as tape:
            predicted_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            likelihood_all = tf.zeros(shape=[batch, 0, 1])
            for predicted_visit_ in range(predicted_visit):
                sequence_time_last_time = input_x_train[:, previous_visit+predicted_visit_-1, :]
                for previous_visit_ in range(previous_visit+predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share([sequence_time, encode_c, encode_h])
                context_state = encode_h

                if predicted_visit_ == 0:
                    decode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                    decode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                current_time_index_shape = tf.ones(shape=[predicted_visit_+previous_visit])
                condition_intensity, likelihood = hawkes_process([input_t_train, current_time_index_shape])
                likelihood_all = tf.concat((likelihood_all, tf.reshape(likelihood, [batch, -1, 1])), axis=1)
                generated_next_visit, decode_c, decode_h = decoder_share([sequence_time_last_time, context_state, decode_c, decode_h*condition_intensity])
                predicted_trajectory = tf.concat((predicted_trajectory, tf.reshape(generated_next_visit, [batch, -1, feature_dims])), axis=1)

            mse_generated_loss = tf.reduce_mean(tf.keras.losses.mse(input_x_train[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory))
            mae_generated_loss = tf.reduce_mean(tf.keras.losses.mae(input_x_train[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory))
            likelihood_loss = tf.reduce_mean(likelihood_all)

            # Per-batch loss; assignment (rather than "+=") avoids carrying a running total across iterations.
            loss = mse_generated_loss * generated_imbalance + likelihood_loss * likelihood_imbalance

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in hawkes_process.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            gradient = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradient, variables))

            if train_set.epoch_completed % 1 == 0 and train_set.epoch_completed not in logged:
                logged.add(train_set.epoch_completed)

                loss_pre = mse_generated_loss
                mse_generated_loss = tf.reduce_mean(
                    tf.keras.losses.mse(input_x_train[:, previous_visit:previous_visit + predicted_visit, :],
                                        predicted_trajectory))

                loss_diff = loss_pre - mse_generated_loss

                if max_loss < mse_generated_loss:
                    count = 0
                else:
                    if max_pace < loss_diff:
                        count = 0

                    else:
                        count += 1
                if count > 9:
                    break

                input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
                input_t_test = tf.cast(test_set[:, :, 0], tf.float32)

                batch_test = input_x_test.shape[0]
                predicted_trajectory_test = tf.zeros(shape=[batch_test, 0, feature_dims])
                for predicted_visit_ in range(predicted_visit):
                    for previous_visit_ in range(previous_visit):
                        sequence_time_test = input_x_test[:, previous_visit_, :]
                        if previous_visit_ == 0:
                            encode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                            encode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        encode_c_test, encode_h_test = encode_share([sequence_time_test, encode_c_test, encode_h_test])

                    if predicted_visit_ != 0:
                        for i in range(predicted_visit_):
                            encode_c_test, encode_h_test = encode_share([predicted_trajectory_test[:, i, :], encode_c_test, encode_h_test])
                    context_state_test = encode_h_test

                    if predicted_visit_ == 0:
                        decode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        decode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        sequence_time_last_time_test = input_x_test[:, predicted_visit_+previous_visit-1, :]

                    current_time_index_shape_test = tf.ones(shape=[previous_visit+predicted_visit_])
                    condition_intensity_test, likelihood_test = hawkes_process([input_t_test, current_time_index_shape_test])

                    sequence_next_visit_test, decode_c_test, decode_h_test = decoder_share([sequence_time_last_time_test, context_state_test, decode_c_test, decode_h_test*condition_intensity_test])
                    predicted_trajectory_test = tf.concat((predicted_trajectory_test, tf.reshape(sequence_next_visit_test, [batch_test, -1, feature_dims])), axis=1)
                    sequence_time_last_time_test = sequence_next_visit_test

                mse_generated_loss_test = tf.reduce_mean(tf.keras.losses.mse(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))
                mae_generated_loss_test = tf.reduce_mean(tf.keras.losses.mae(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))

                r_value_all = []
                for patient in range(batch_test):
                    r_value = 0.0
                    for feature in range(feature_dims):
                        x_ = input_x_test[patient, previous_visit:, feature].numpy().reshape(predicted_visit, 1)
                        y_ = predicted_trajectory_test[patient, :, feature].numpy().reshape(predicted_visit, 1)
                        r_value += DynamicTimeWarping(x_, y_)
                    r_value_all.append(r_value / feature_dims)  # mean DTW over features
                print("epoch  {}---train_mse_generate {}- - "
                      "mae_generated_loss--{}--test_mse {}--test_mae  "
                      "{}--r_value {}-count {}".format(train_set.epoch_completed,
                                                       mse_generated_loss,
                                                       mae_generated_loss,
                                                       mse_generated_loss_test,
                                                       mae_generated_loss_test,
                                                       np.mean(r_value_all),
                                                       count))


                # r_value_all = []
                # p_value_all = []
                # r_value_spearman_all = []
                # r_value_kendall_all = []
                # for visit in range(predicted_visit):
                #     for feature in range(feature_dims):
                #         x_ = input_x_test[:, previous_visit+visit, feature]
                #         y_ = predicted_trajectory_test[:, visit, feature]
                #         r_value_ = stats.pearsonr(x_, y_)
                #         r_value_spearman = stats.spearmanr(x_, y_)
                #         r_value_kendall = stats.kendalltau(x_, y_)
                #         if not np.isnan(r_value_[0]):
                #             r_value_all.append(np.abs(r_value_[0]))
                #             p_value_all.append(np.abs(r_value_[1]))
                #         if not np.isnan(r_value_spearman[0]):
                #             r_value_spearman_all.append(np.abs(r_value_spearman[0]))
                #         if not np.isnan(r_value_kendall[0]):
                #             r_value_kendall_all.append(np.abs(r_value_kendall[0]))

                # print("epoch  {}---train_mse_generate {}- - "
                #       "mae_generated_loss--{}--test_mse {}--test_mae  "
                #       "{}----r_value {}--r_spearman---{}-"
                #       "r_kendall---{}    -count {}".format(train_set.epoch_completed,
                #                                            mse_generated_loss,
                #                                            mae_generated_loss,
                #                                            mse_generated_loss_test,
                #                                            mae_generated_loss_test,
                #                                            np.mean(r_value_all),
                #                                            np.mean(r_value_spearman_all),
                #                                            np.mean(r_value_kendall_all),
                #                                            count))
    tf.compat.v1.reset_default_graph()
    return mse_generated_loss_test, mae_generated_loss_test, np.mean(r_value_all)