Example #1
def make_model(cnn3d,
               tgt_vocab,
               N=6,
               d_model=512,
               d_ff=2048,
               h=8,
               dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab), cnn3d)

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for name, p in model.named_parameters():
        if not name.startswith("cnn3d") and p.requires_grad and p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model
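A minimal usage sketch (not part of the original snippet; cnn3d_backbone stands for whatever 3D-CNN feature extractor the surrounding project provides, and the vocabulary size is a placeholder):

cnn3d_backbone = ...  # any nn.Module producing d_model-sized frame features
model = make_model(cnn3d_backbone, tgt_vocab=10000)
model.train()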
Example #2
def validate_model(model, info, file_name):
    model.eval()
    decoder = Decoder()
    labels, paths = info
    results = []

    with torch.no_grad():
        for path, label in zip(paths, labels):
            spec = get_spectrogram(path)
            label = strip_label(label)

            output = model(spec)
            output = F.log_softmax(output, dim=2)
            output = output.permute(1, 0, 2)

            decoded_output = decoder.greedy_decode(output.numpy())
            results.append((label, decoded_output))

    with open(file_name, 'w') as f:
        for label, pred in results:
            f.write(f'\n\n{label}\n{pred}')
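The Decoder used above is not shown in this excerpt; below is a minimal sketch of what a greedy CTC decoder typically does (frame-wise argmax, collapse repeats, drop the blank index). The labels list and blank index are assumptions, not the project's actual Decoder:

import numpy as np

class GreedyCTCDecoder:
    # Sketch only, not the project's Decoder class.
    def __init__(self, labels, blank=0):
        self.labels = labels
        self.blank = blank

    def greedy_decode(self, log_probs):
        # log_probs: (time, batch, classes); decode the first batch element
        best = np.argmax(log_probs[:, 0, :], axis=-1)
        out, prev = [], self.blank
        for idx in best:
            if idx != prev and idx != self.blank:
                out.append(self.labels[idx])
            prev = idx
        return ''.join(out)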
Example #3
import torch
import zipfile
import torchaudio
from glob import glob

from utils import (Decoder, read_audio, read_batch, split_into_batches,
                   prepare_model_input)

device = torch.device(
    'cpu')  # gpu also works, but our models are fast enough for CPU

model = torch.jit.load('en_v2_jit.model', map_location=device)
model.eval()

decoder = Decoder(model.labels)

# model, decoder, utils = torch.hub.load(repo_or_dir='~/silero-models',
#                                        model='silero_stt',
#                                        language='en', # also available 'de', 'es'
#                                        device=device)
# (read_batch, split_into_batches,
#  read_audio, prepare_model_input) = utils  # see function signature for details

# download a single file, any format compatible with TorchAudio (soundfile backend)
# torch.hub.download_url_to_file('https://opus-codec.org/static/examples/samples/speech_orig.wav',
#                                dst ='speech_orig.wav', progress=True)

test_files = glob('speech_orig.wav')
batches = split_into_batches(test_files, batch_size=10)
input = prepare_model_input(read_batch(batches[0]), device=device)
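Running the model and decoding the output (this continuation mirrors the silero-models README and is not part of the original excerpt):

output = model(input)
for example in output:
    print(decoder(example.cpu()))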
Example #4
if __name__ == '__main__':
    sym_spell = SymSpell(max_dictionary_edit_distance=2)
    dictionary_path = pkg_resources.resource_filename("symspellpy", "frequency_dictionary_en_82_765.txt")
    sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)

    with open('kite_reading.txt', 'r') as f:
        labels = f.readlines()

    files = [f'kite_clips_0/clip{i}.wav' for i in range(len(os.listdir('kite_clips_0')))]
    checkpoint = torch.load('ckpt-5.pth')
    model = checkpoint['model']
    model.eval()
    decoder = Decoder()

    with torch.no_grad():
        for label, file in zip(labels, files):
            waveform, sample_rate = torchaudio.load(file)
            waveform = waveform.unsqueeze(0)
            resample = torchaudio.transforms.Resample(sample_rate, 8000)
            transform = LogMelSpectrogram(8000, 400, 0.5, 128)
            waveform = resample(waveform)
            spec = transform(waveform)

            output = F.softmax(model(spec), dim=2)
            output = output.permute(1, 0, 2)
            beams = prefix_beam_search(output.squeeze(1).numpy(), k=10)

            all_beam_predictions = aggregate_predictions(beams)
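            # (Not in the original excerpt) The SymSpell instance created above is
            # otherwise unused here; a plausible follow-up step is compound spelling
            # correction of the decoded beams. Sketch only -- assumes
            # all_beam_predictions is an iterable of strings:
            for pred in all_beam_predictions:
                suggestions = sym_spell.lookup_compound(pred, max_edit_distance=2)
                if suggestions:
                    print(label.strip(), '->', suggestions[0].term)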
Example #5
    def __init__(self):
        modelname = 'en_v3_jit.model'
        self.model = torch.jit.load(modelname)
        self.decoder = Decoder(self.model.labels)
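    # (Sketch, not in the original) a transcribe helper consistent with the
    # silero utilities used in Example #3; read_batch and prepare_model_input
    # are assumed to be imported from the same utils module.
    def transcribe(self, wav_path):
        batch = read_batch([wav_path])
        model_input = prepare_model_input(batch)
        output = self.model(model_input)
        return self.decoder(output[0].cpu())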
Example #6
def main():
    # Load setting from json file
    with open('../lstm_save/dataset_param.json', 'r') as f:
        setting = json.load(f)
    
    f = open("../lstm_save/result_large.txt", "w+")
        
    num_input_tokens = setting['num_input_tokens']
    num_output_tokens = setting['num_output_tokens']
    max_input_len = setting['max_input_len']
    max_output_len = setting['max_output_len']
    input_token_index = setting['input_token_index']
    output_token_index = setting['output_token_index']
    # hidden_dim = setting['hidden_dim']
    hidden_dim = 256

    reverse_output_token_index = dict((i, char) for char, i in output_token_index.items())

    # Load model from h5 file
    model = load_model(
        "../lstm_save/model.h5",
        custom_objects={"match_rate":match_rate}
    )

    model.summary()

    decoder = Decoder(
        model=model,
        hidden_dim=hidden_dim,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        input_token_index=input_token_index,
        output_token_index=output_token_index,
        reverse_output_token_index=reverse_output_token_index
    )

    # Test1
    # print(decoder.predict("x|y"))
    # print(verify_equivalent("-2*(~(x&y))-1", "x|y"))
    # exit()

    # # Test2
    path = "../../data/linear/test/test_data.csv"
    test_data = pd.read_csv(path, header=None)
    mba_exprs, targ_exprs = test_data[0], test_data[1]

    wrong_predict_statistic = []
    correct_predict_count = 0
    z3_verify_correct_count = 0
    test_count = len(test_data)
    time_sum = 0
    max_time = 0
    min_time = float('inf')  # start at infinity so the true minimum is tracked
    total_len = 0
    for idx in range(test_count):
        print("No.%d" % (idx + 1), end=' ', file=f)
        print("No.%d" % (idx + 1), end=' ')
        print("=" * 50, file=f)
        print("MBA expr:", mba_exprs[idx], file=f)
        print("Targ expr:", targ_exprs[idx], file=f)
        start_time = time.time()
        predict_expr = decoder.predict(mba_exprs[idx])
        total_len += len(predict_expr)
        print("Pred expr:", predict_expr, file=f)
        end_time = time.time()
        consume_time = end_time - start_time
        time_sum += consume_time
        if max_time < consume_time:
            max_time = consume_time
        if min_time > consume_time:
            min_time = consume_time
        if predict_expr == targ_exprs[idx]:
            print("Predict \033[1;32m True \033[0m")
            print("Predict True", file=f)
            correct_predict_count += 1
        else:
            z3Result = verify_equivalent(predict_expr, targ_exprs[idx])
            if z3Result != 'unsat':
                print("Predict \033[1;31m False \033[0m")
                print("Predict False", file=f)
                wrong_predict_statistic.append([mba_exprs[idx], targ_exprs[idx], predict_expr])
            else:
                z3_verify_correct_count += 1
                print("Predict \033[1;33m Z3 True \033[0m")
                print("Predict Z3 True", file=f)
        print("Time = %.4f" % consume_time, file=f)
        print("", file=f)
    print("#Correct predict: %d/%d" % (correct_predict_count, test_count), file=f)
    print("#False predict true Z3:", z3_verify_correct_count, file=f)
    print("#Correct rate: %.4f" % ((correct_predict_count+z3_verify_correct_count)/test_count), file=f)
    print("Average solve time: %.4f" % (time_sum / test_count), file=f)
    print("Maximum solve time: %.4f" % (max_time), file=f)
    print("Minimum solve time: %.4f" % (min_time), file=f)
    print("Average result length: %.4f" % (total_len/test_count))

    pd.DataFrame(wrong_predict_statistic).to_csv("wrong_predict_statistic.csv", mode='w+', header=False, index=False)
    
    f.close()
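verify_equivalent is not defined in this excerpt. In the MBA-simplification setting it presumably asks Z3 whether the predicted and target expressions can ever differ, so 'unsat' means equivalent, matching the check in the loop above. A minimal sketch under that assumption, using 8-bit bit-vector variables x, y, z, t:

from z3 import BitVec, Solver

def verify_equivalent(expr1, expr2, width=8):
    # Sketch only: build bit-vector variables, evaluate both expression strings
    # over them, and ask Z3 for an input on which they differ.
    env = {name: BitVec(name, width) for name in ('x', 'y', 'z', 't')}
    solver = Solver()
    solver.add(eval(expr1, {}, env) != eval(expr2, {}, env))
    return str(solver.check())  # 'unsat' means no counterexample, i.e. equivalent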
Example #7
def read_img_and_preprocess(path, img_height=32, img_width=None):
    # (function header reconstructed; it was cut off in the excerpt and the
    # default values are assumptions)
    img = tf.io.read_file(path)
    img = tf.io.decode_jpeg(img, channels=args.img_channels)
    if not img_width:
        img_shape = tf.shape(img)
        scale_factor = img_height / img_shape[0]
        img_width = scale_factor * tf.cast(img_shape[1], tf.float64)
        img_width = tf.cast(img_width, tf.int32)
    img = tf.image.resize(img, (img_height, img_width)) / 255.0
    return img


with open(args.table_path, 'r') as f:
    table = [char.strip() for char in f]

model = keras.models.load_model(args.model, compile=False)

decoder = Decoder(table)

p = Path(args.images)
if p.is_dir():
    img_paths = p.iterdir()
else:
    img_paths = [p]

for img_path in img_paths:
    img = read_img_and_preprocess(str(img_path))
    img = tf.expand_dims(img, 0)
    y_pred = model(img)
    g_decode = decoder.decode(y_pred, method='greedy')[0]
    b_decode = decoder.decode(y_pred, method='beam_search')[0]
    print(f'Path: {img_path}, greedy: {g_decode}, beam search: {b_decode}')
Example #8
def read_img_and_preprocess(path):
    img = tf.io.read_file(path)
    img = tf.io.decode_jpeg(img, channels=args.img_channels)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, (32, args.img_width))
    return img


p = Path(args.images)
if p.is_dir():
    img_paths = p.iterdir()
    imgs = [read_img_and_preprocess(str(x)) for x in p.iterdir()]
    imgs = tf.stack(imgs)
else:
    img_paths = [p]
    img = read_img_and_preprocess(str(p))
    imgs = tf.expand_dims(img, 0)

with open(args.table_path, 'r') as f:
    inv_table = [char.strip() for char in f]

model = keras.models.load_model(args.model, compile=False)

decoder = Decoder(inv_table)

y_pred = model(imgs)
for path, g_pred, b_pred in zip(img_paths,
                                decoder.decode(y_pred, method='greedy'),
                                decoder.decode(y_pred, method='beam_search')):
    print('Path: {}, greedy: {}, beam search: {}'.format(path, g_pred, b_pred))
Example #9
def predict_one_image(retinanet, image):
    # Normalize the image to [0, 1]
    image = image.astype(np.float32) / 255.0
    print('source image shape:', image.shape)

    # Resize the image to a network-friendly size; see 'class Resizer(object)' in 'utils.py' for details
    min_side, max_side = 400, 800
    rows, cols, cns = image.shape
    smallest_side = min(rows, cols)
    scale = min_side / smallest_side
    largest_side = max(rows, cols)
    if largest_side * scale > max_side:
        scale = max_side / largest_side
    image = cv2.resize(image,
                       (int(round(cols * scale)), int(round(rows * scale))))
    print('resize image shape:', image.shape)
    rows, cols, cns = image.shape
    pad_w = 32 - rows % 32
    pad_h = 32 - cols % 32
    net_input = np.zeros((rows + pad_w, cols + pad_h, cns)).astype(np.float32)
    net_input[:rows, :cols, :] = image.astype(np.float32)

    # Convert net_input into the format RetinaNet expects (NCHW float tensor)
    net_input = torch.Tensor(net_input)
    net_input = net_input.unsqueeze(dim=0)
    net_input = net_input.permute(0, 3, 1, 2)
    print('RetinaNet input size:', net_input.size())

    anchor = Anchor()
    decoder = Decoder()

    if cuda:
        net_input = net_input.cuda()
        anchor = anchor.cuda()
        decoder = decoder.cuda()

    total_anchors = anchor(net_input)
    print('create anchor number:', total_anchors.size()[0])
    classification, localization = retinanet(net_input)

    pred_boxes = decoder(total_anchors, localization)

    # Boxes in pred_boxes can extend beyond the image borders; clamp them back inside
    height, width, _ = image.shape
    pred_boxes[:, 0] = torch.clamp(pred_boxes[:, 0], min=0)
    pred_boxes[:, 1] = torch.clamp(pred_boxes[:, 1], min=0)
    pred_boxes[:, 2] = torch.clamp(pred_boxes[:, 2], max=width)
    pred_boxes[:, 3] = torch.clamp(pred_boxes[:, 3], max=height)

    # classification: [1, -1, 80]
    # torch.max(classification, dim=2, keepdim=True): [(1, -1, 1), (1, -1, 1)]
    # scores: [1, -1, 1], the maximum class confidence for each anchor
    # (ss holds the corresponding class ids)
    scores, ss = torch.max(classification, dim=2, keepdim=True)

    scores_over_thresh = (scores > 0.05)[0, :, 0]  # [True or False]
    if scores_over_thresh.sum() == 0:
        # no boxes to NMS, just return
        nms_scores = torch.zeros(0)
        nms_cls = torch.zeros(0)
        nms_boxes = torch.zeros(0, 4)
    else:
        # Keep classification scores for anchors whose max confidence exceeds the threshold
        classification = classification[:, scores_over_thresh, :]
        # Keep pred_boxes for anchors whose max confidence exceeds the threshold
        pred_boxes = pred_boxes[scores_over_thresh, :]
        # Keep scores for anchors whose max confidence exceeds the threshold
        scores = scores[:, scores_over_thresh, :]

        nms_ind = nms(pred_boxes[:, :], scores[0, :, 0], 0.5)

        nms_scores, nms_cls = classification[0, nms_ind, :].max(dim=1)
        nms_boxes = pred_boxes[nms_ind, :]

    print('Predict bounding boxes number:', nms_scores.size()[0])
    bounding_boxes = [nms_scores, nms_cls, nms_boxes]

    imshow_result(image, bounding_boxes)
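A usage sketch (not part of the original; the checkpoint and image paths are placeholders, and the module-level cuda flag referenced inside predict_one_image must be defined beforehand):

cuda = torch.cuda.is_available()
retinanet = torch.load('retinanet_ckpt.pth', map_location='cuda' if cuda else 'cpu')
retinanet.eval()
if cuda:
    retinanet = retinanet.cuda()
image = cv2.cvtColor(cv2.imread('demo.jpg'), cv2.COLOR_BGR2RGB)
with torch.no_grad():
    predict_one_image(retinanet, image)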
Example #10
def train(hidden_size, learning_rate, l2_regularization, n_disc,
          generated_mse_imbalance, generated_loss_imbalance,
          likelihood_imbalance):
    # train_set = np.load("../../Trajectory_generate/dataset_file/train_x_.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/test_x.npy").reshape(-1, 6, 60)
    # test_set = np.load("../../Trajectory_generate/dataset_file/validate_x_.npy").reshape(-1, 6, 60)

    train_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_train_.npy').reshape(
            -1, 6, 30)
    # test_set = np.load('../../Trajectory_generate/dataset_file/HF_validate_.npy').reshape(-1, 6, 30)
    test_set = np.load(
        '../../Trajectory_generate/dataset_file/HF_test_.npy').reshape(
            -1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    # train_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 3

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50

    # hidden_size = 2 ** (int(hidden_size))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # n_disc = int(n_disc)
    # generated_mse_imbalance = 10 ** generated_mse_imbalance
    # generated_loss_imbalance = 10 ** generated_loss_imbalance
    # likelihood_imbalance = 10 ** likelihood_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(
        previous_visit, predicted_visit))

    print(
        'hidden_size---{}---learning_rate---{}---l2_regularization---{}---n_disc---{}'
        'generated_mse_imbalance---{}---generated_loss_imbalance---{}---'
        'likelihood_imbalance---{}'.format(hidden_size, learning_rate,
                                           l2_regularization, n_disc,
                                           generated_mse_imbalance,
                                           generated_loss_imbalance,
                                           likelihood_imbalance))
    encode_share = Encoder(hidden_size=hidden_size)
    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    hawkes_process = HawkesProcess()
    discriminator = Discriminator(previous_visit=previous_visit,
                                  predicted_visit=predicted_visit,
                                  hidden_size=hidden_size)

    logged = set()
    max_loss = 0.001
    max_pace = 0.0001
    count = 0
    loss = 0
    optimizer_generation = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    optimizer_discriminator = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate)
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)
        batch = input_train.shape[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape(
                persistent=True) as disc_tape:
            generated_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            probability_likelihood = tf.zeros(shape=[batch, 0, 1])
            for predicted_visit_ in range(predicted_visit):
                sequence_last_time = input_x_train[:, previous_visit +
                                                   predicted_visit_ - 1, :]
                for previous_visit_ in range(previous_visit +
                                             predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(
                            tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share(
                        [sequence_time, encode_c, encode_h])
                context_state = encode_h

                if predicted_visit_ == 0:
                    decode_c = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))
                    decode_h = tf.Variable(
                        tf.zeros(shape=[batch, hidden_size]))

                current_time_index_shape = tf.ones(
                    shape=[previous_visit + predicted_visit_])
                intensity_value, likelihood = hawkes_process(
                    [input_t_train, current_time_index_shape])
                probability_likelihood = tf.concat(
                    (probability_likelihood,
                     tf.reshape(likelihood, [batch, -1, 1])),
                    axis=1)

                generated_next_visit, decode_c, decode_h = decoder_share([
                    sequence_last_time, context_state, decode_c,
                    decode_h * intensity_value
                ])
                generated_trajectory = tf.concat(
                    (generated_trajectory,
                     tf.reshape(generated_next_visit,
                                [batch, -1, feature_dims])),
                    axis=1)

            d_real_pre_, d_fake_pre_ = discriminator(input_x_train,
                                                     generated_trajectory)
            d_real_pre_loss = cross_entropy(tf.ones_like(d_real_pre_),
                                            d_real_pre_)
            d_fake_pre_loss = cross_entropy(tf.zeros_like(d_fake_pre_),
                                            d_fake_pre_)
            d_loss = d_real_pre_loss + d_fake_pre_loss

            gen_loss = cross_entropy(tf.ones_like(d_fake_pre_), d_fake_pre_)
            generated_mse_loss = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            likelihood_loss = tf.reduce_mean(probability_likelihood)

            loss += generated_mse_loss * generated_mse_imbalance + likelihood_loss * likelihood_imbalance + \
                    gen_loss * generated_loss_imbalance

            for weight in discriminator.trainable_variables:
                d_loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

            for weight in hawkes_process.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)
                variables.append(weight)

        for disc in range(n_disc):
            gradient_disc = disc_tape.gradient(
                d_loss, discriminator.trainable_variables)
            optimizer_discriminator.apply_gradients(
                zip(gradient_disc, discriminator.trainable_variables))

        gradient_gen = gen_tape.gradient(loss, variables)
        optimizer_generation.apply_gradients(zip(gradient_gen, variables))

        if train_set.epoch_completed % 1 == 0 and train_set.epoch_completed not in logged:
            logged.add(train_set.epoch_completed)
            loss_pre = generated_mse_loss

            mse_generated = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_train[:, previous_visit:previous_visit +
                                  predicted_visit, :], generated_trajectory))

            loss_diff = loss_pre - mse_generated

            if mse_generated > max_loss:
                count = 0
            else:
                if loss_diff > max_pace:
                    count = 0
                else:
                    count += 1
            if count > 9:
                break

            input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
            input_t_test = tf.cast(test_set[:, :, 0], tf.float32)

            batch_test = test_set.shape[0]
            generated_trajectory_test = tf.zeros(
                shape=[batch_test, 0, feature_dims])
            for predicted_visit_ in range(predicted_visit):
                for previous_visit_ in range(previous_visit):
                    sequence_time_test = input_x_test[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))
                        encode_h_test = tf.Variable(
                            tf.zeros(shape=[batch_test, hidden_size]))

                    encode_c_test, encode_h_test = encode_share(
                        [sequence_time_test, encode_c_test, encode_h_test])

                if predicted_visit_ != 0:
                    for i in range(predicted_visit_):
                        encode_c_test, encode_h_test = encode_share([
                            generated_trajectory_test[:, i, :], encode_c_test,
                            encode_h_test
                        ])

                context_state_test = encode_h_test

                if predicted_visit_ == 0:
                    decode_c_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    decode_h_test = tf.Variable(
                        tf.zeros(shape=[batch_test, hidden_size]))
                    sequence_last_time_test = input_x_test[:, previous_visit +
                                                           predicted_visit_ -
                                                           1, :]

                current_time_index_shape = tf.ones(
                    [previous_visit + predicted_visit_])
                intensity_value, likelihood = hawkes_process(
                    [input_t_test, current_time_index_shape])
                generated_next_visit, decode_c_test, decode_h_test = decoder_share(
                    [
                        sequence_last_time_test, context_state_test,
                        decode_c_test, decode_h_test * intensity_value
                    ])
                generated_trajectory_test = tf.concat(
                    (generated_trajectory_test,
                     tf.reshape(generated_next_visit,
                                [batch_test, -1, feature_dims])),
                    axis=1)
                sequence_last_time_test = generated_next_visit

            mse_generated_test = tf.reduce_mean(
                tf.keras.losses.mse(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))
            mae_generated_test = tf.reduce_mean(
                tf.keras.losses.mae(
                    input_x_test[:, previous_visit:previous_visit +
                                 predicted_visit, :],
                    generated_trajectory_test))

            r_value_all = []
            for patient in range(batch_test):
                r_value = 0.0
                for feature in range(feature_dims):
                    x_ = input_x_test[patient, previous_visit:,
                                      feature].numpy().reshape(
                                          predicted_visit, 1)
                    y_ = generated_trajectory_test[patient, :,
                                                   feature].numpy().reshape(
                                                       predicted_visit, 1)
                    r_value += DynamicTimeWarping(x_, y_)
                r_value_all.append(r_value / 29.0)

            print(
                '------epoch{}------mse_loss{}----mae_loss{}------predicted_r_value---{}--'
                '-count  {}'.format(train_set.epoch_completed,
                                    mse_generated_test, mae_generated_test,
                                    np.mean(r_value_all), count))

            # r_value_all = []
            # p_value_all = []
            # r_value_spearman = []
            # r_value_kendalltau = []
            # for visit in range(predicted_visit):
            #     for feature in range(feature_dims):
            #         x_ = input_x_test[:, previous_visit+visit, feature]
            #         y_ = generated_trajectory_test[:, visit, feature]
            #         r_value_ = stats.pearsonr(x_, y_)
            #         r_value_spearman_ = stats.spearmanr(x_, y_)
            #         r_value_kendalltau_ = stats.kendalltau(x_, y_)
            #         if not np.isnan(r_value_[0]):
            #             r_value_all.append(np.abs(r_value_[0]))
            #             p_value_all.append(np.abs(r_value_[1]))
            #         if not np.isnan(r_value_spearman_[0]):
            #             r_value_spearman.append(np.abs(r_value_spearman_[0]))
            #         if not np.isnan(r_value_kendalltau_[0]):
            #             r_value_kendalltau.append(np.abs(r_value_kendalltau_[0]))
            # print('------epoch{}------mse_loss{}----mae_loss{}------predicted_r_value---{}--'
            #       'r_value_spearman---{}---r_value_kendalltau---{}--count  {}'.format(train_set.epoch_completed,
            #                                                                           mse_generated_test,

            #                                                                           mae_generated_test,
            #                                                                           np.mean(r_value_all),
            #                                                                           np.mean(r_value_spearman),
            #                                                                           np.mean(r_value_kendalltau),
            #                                                                           count))

    tf.compat.v1.reset_default_graph()
    return mse_generated_test, mae_generated_test, np.mean(r_value_all)


def train(hidden_size, l2_regularization, learning_rate, generated_imbalance, likelihood_imbalance):
    train_set = np.load("../../Trajectory_generate/dataset_file/HF_train_.npy").reshape(-1, 6, 30)
    test_set = np.load("../../Trajectory_generate/dataset_file/HF_test_.npy").reshape(-1, 6, 30)
    # test_set = np.load("../../Trajectory_generate/dataset_file/HF_validate_.npy").reshape(-1, 6, 30)

    # train_set = np.load("../../Trajectory_generate/dataset_file/mimic_train_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_test_x_.npy").reshape(-1, 6, 37)
    # test_set = np.load("../../Trajectory_generate/dataset_file/mimic_validate_.npy").reshape(-1, 6, 37)

    # sepsis mimic dataset
    # train_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_train.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_test.npy').reshape(-1, 13, 40)
    # test_set = np.load('../../Trajectory_generate/dataset_file/sepsis_mimic_validate.npy').reshape(-1, 13, 40)

    previous_visit = 3
    predicted_visit = 3

    feature_dims = train_set.shape[2] - 1

    train_set = DataSet(train_set)
    train_set.epoch_completed = 0
    batch_size = 64
    epochs = 50

    # hidden_size = 2 ** (int(hidden_size))
    # learning_rate = 10 ** learning_rate
    # l2_regularization = 10 ** l2_regularization
    # generated_imbalance = 10 ** generated_imbalance
    # likelihood_imbalance = 10 ** likelihood_imbalance

    print('previous_visit---{}---predicted_visit----{}-'.format(previous_visit, predicted_visit))

    print('hidden_size----{}---'
          'l2_regularization---{}---'
          'learning_rate---{}---'
          'generated_imbalance---{}---'
          'likelihood_imbalance---{}'.
          format(hidden_size, l2_regularization, learning_rate,
                 generated_imbalance, likelihood_imbalance))

    decoder_share = Decoder(hidden_size=hidden_size, feature_dims=feature_dims)
    encode_share = Encoder(hidden_size=hidden_size)
    hawkes_process = HawkesProcess()

    logged = set()
    max_loss = 0.01
    max_pace = 0.001
    loss = 0

    count = 0
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    while train_set.epoch_completed < epochs:
        input_train = train_set.next_batch(batch_size=batch_size)
        batch = input_train.shape[0]
        input_x_train = tf.cast(input_train[:, :, 1:], tf.float32)
        input_t_train = tf.cast(input_train[:, :, 0], tf.float32)

        with tf.GradientTape() as tape:
            predicted_trajectory = tf.zeros(shape=[batch, 0, feature_dims])
            likelihood_all = tf.zeros(shape=[batch, 0, 1])
            for predicted_visit_ in range(predicted_visit):
                sequence_time_last_time = input_x_train[:, previous_visit+predicted_visit_-1, :]
                for previous_visit_ in range(previous_visit+predicted_visit_):
                    sequence_time = input_x_train[:, previous_visit_, :]
                    if previous_visit_ == 0:
                        encode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                        encode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))

                    encode_c, encode_h = encode_share([sequence_time, encode_c, encode_h])
                context_state = encode_h

                if predicted_visit_ == 0:
                    decode_c = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                    decode_h = tf.Variable(tf.zeros(shape=[batch, hidden_size]))
                current_time_index_shape = tf.ones(shape=[predicted_visit_+previous_visit])
                condition_intensity, likelihood = hawkes_process([input_t_train, current_time_index_shape])
                likelihood_all = tf.concat((likelihood_all, tf.reshape(likelihood, [batch, -1, 1])), axis=1)
                generated_next_visit, decode_c, decode_h = decoder_share([sequence_time_last_time, context_state, decode_c, decode_h*condition_intensity])
                predicted_trajectory = tf.concat((predicted_trajectory, tf.reshape(generated_next_visit, [batch, -1, feature_dims])), axis=1)

            mse_generated_loss = tf.reduce_mean(tf.keras.losses.mse(input_x_train[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory))
            mae_generated_loss = tf.reduce_mean(tf.keras.losses.mae(input_x_train[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory))
            likelihood_loss = tf.reduce_mean(likelihood_all)

            loss += mse_generated_loss * generated_imbalance + likelihood_loss * likelihood_imbalance

            variables = [var for var in encode_share.trainable_variables]
            for weight in encode_share.trainable_variables:
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in decoder_share.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            for weight in hawkes_process.trainable_variables:
                variables.append(weight)
                loss += tf.keras.regularizers.l2(l2_regularization)(weight)

            gradient = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradient, variables))

            if train_set.epoch_completed % 1 == 0 and train_set.epoch_completed not in logged:
                logged.add(train_set.epoch_completed)

                loss_pre = mse_generated_loss
                mse_generated_loss = tf.reduce_mean(
                    tf.keras.losses.mse(input_x_train[:, previous_visit:previous_visit + predicted_visit, :],
                                        predicted_trajectory))

                loss_diff = loss_pre - mse_generated_loss

                if max_loss < mse_generated_loss:
                    count = 0
                else:
                    if max_pace < loss_diff:
                        count = 0

                    else:
                        count += 1
                if count > 9:
                    break

                input_x_test = tf.cast(test_set[:, :, 1:], tf.float32)
                input_t_test = tf.cast(test_set[:, :, 0], tf.float32)

                batch_test = input_x_test.shape[0]
                predicted_trajectory_test = tf.zeros(shape=[batch_test, 0, feature_dims])
                for predicted_visit_ in range(predicted_visit):
                    for previous_visit_ in range(previous_visit):
                        sequence_time_test = input_x_test[:, previous_visit_, :]
                        if previous_visit_ == 0:
                            encode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                            encode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        encode_c_test, encode_h_test = encode_share([sequence_time_test, encode_c_test, encode_h_test])

                    if predicted_visit_ != 0:
                        for i in range(predicted_visit_):
                            encode_c_test, encode_h_test = encode_share([predicted_trajectory_test[:, i, :], encode_c_test, encode_h_test])
                    context_state_test = encode_h_test

                    if predicted_visit_ == 0:
                        decode_c_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        decode_h_test = tf.Variable(tf.zeros(shape=[batch_test, hidden_size]))
                        sequence_time_last_time_test = input_x_test[:, predicted_visit_+previous_visit-1, :]

                    current_time_index_shape_test = tf.ones(shape=[previous_visit+predicted_visit_])
                    condition_intensity_test, likelihood_test = hawkes_process([input_t_test, current_time_index_shape_test])

                    sequence_next_visit_test, decode_c_test, decode_h_test = decoder_share([sequence_time_last_time_test, context_state_test, decode_c_test, decode_h_test*condition_intensity_test])
                    predicted_trajectory_test = tf.concat((predicted_trajectory_test, tf.reshape(sequence_next_visit_test, [batch_test, -1, feature_dims])), axis=1)
                    sequence_time_last_time_test = sequence_next_visit_test

                mse_generated_loss_test = tf.reduce_mean(tf.keras.losses.mse(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))
                mae_generated_loss_test = tf.reduce_mean(tf.keras.losses.mae(input_x_test[:, previous_visit:previous_visit+predicted_visit, :], predicted_trajectory_test))

                r_value_all = []
                for patient in range(batch_test):
                    r_value = 0.0
                    for feature in range(feature_dims):
                        x_ = input_x_test[patient, previous_visit:, feature].numpy().reshape(predicted_visit, 1)
                        y_ = predicted_trajectory_test[patient, :, feature].numpy().reshape(predicted_visit, 1)
                        r_value += DynamicTimeWarping(x_, y_)
                    r_value_all.append(r_value / 29.0)
                print("epoch  {}---train_mse_generate {}- - "
                      "mae_generated_loss--{}--test_mse {}--test_mae  "
                      "{}--r_value {}-count {}".format(train_set.epoch_completed,
                                                       mse_generated_loss,
                                                       mae_generated_loss,
                                                       mse_generated_loss_test,
                                                       mae_generated_loss_test,
                                                       np.mean(r_value_all),
                                                       count))


                # r_value_all = []
                # p_value_all = []
                # r_value_spearman_all = []
                # r_value_kendall_all = []
                # for visit in range(predicted_visit):
                #     for feature in range(feature_dims):
                #         x_ = input_x_test[:, previous_visit+visit, feature]
                #         y_ = predicted_trajectory_test[:, visit, feature]
                #         r_value_ = stats.pearsonr(x_, y_)
                #         r_value_spearman = stats.spearmanr(x_, y_)
                #         r_value_kendall = stats.kendalltau(x_, y_)
                #         if not np.isnan(r_value_[0]):
                #             r_value_all.append(np.abs(r_value_[0]))
                #             p_value_all.append(np.abs(r_value_[1]))
                #         if not np.isnan(r_value_spearman[0]):
                #             r_value_spearman_all.append(np.abs(r_value_spearman[0]))
                #         if not np.isnan(r_value_kendall[0]):
                #             r_value_kendall_all.append(np.abs(r_value_kendall[0]))

                # print("epoch  {}---train_mse_generate {}- - "
                #       "mae_generated_loss--{}--test_mse {}--test_mae  "
                #       "{}----r_value {}--r_spearman---{}-"
                #       "r_kendall---{}    -count {}".format(train_set.epoch_completed,
                #                                            mse_generated_loss,
                #                                            mae_generated_loss,
                #                                            mse_generated_loss_test,
                #                                            mae_generated_loss_test,
                #                                            np.mean(r_value_all),
                #                                            np.mean(r_value_spearman_all),
                #                                            np.mean(r_value_kendall_all),
                #                                            count))
    tf.compat.v1.reset_default_graph()
    return mse_generated_loss_test, mae_generated_loss_test, np.mean(r_value_all)
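DynamicTimeWarping is not defined in this excerpt; both train functions use it as a per-feature distance between the observed and generated visit sequences. A minimal sketch of a standard DTW distance under that assumption:

import numpy as np

def DynamicTimeWarping(x, y):
    # Sketch of classic DTW with absolute-difference cost; x and y are
    # column vectors of shape (predicted_visit, 1) as passed above.
    x, y = np.asarray(x).ravel(), np.asarray(y).ravel()
    n, m = len(x), len(y)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(x[i - 1] - y[j - 1])
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return cost[n, m]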