# Example No. 1
# 0
def training(file_path, epochs, noise_std, do_gif=False):
    """Build a model for every ``.csv`` file in *file_path* and plot its data.

    For each file ending in ``.csv`` directly inside *file_path* (visited in
    sorted order so runs are reproducible), this:

    1. reads the file once with ``read_data(..., normalize='std')``;
    2. splits it into train/test sets (``test_size=0.1``, ``random_state=42``);
    3. builds noisy copies of both splits via ``normal(0, noise_std, ...)``;
    4. builds and compiles a bidirectional-LSTM ``Sequential`` model;
    5. calls ``network_model.plot`` on the full dataset, writing
       ``<name>_<epochs>.jpg`` into ``ENC_PLOT_DATA_PATH``.

    Args:
        file_path: Directory scanned for ``.csv`` files.
        epochs: In this function, only used in the output image file name.
        noise_std: Standard deviation passed to ``normal`` when adding noise
            to the train/test splits.
        do_gif: Not used by the visible code.

    NOTE(review): no training/fit call is visible here -- ``epochs``,
    ``do_gif``, ``create_gif`` and the noisy splits are never consumed.
    Confirm whether a training step was dropped.
    """
    csv_suffix = ".csv"

    def create_gif(path, title, step=1):
        """Combine every *step*-th image in *path* into ``ENC_PLOT_DATA_PATH/title``.

        Frames are read in sorted file-name order, because ``os.listdir``
        returns entries in arbitrary order and would shuffle the animation.
        The source directory *path* is deleted afterwards.
        """
        frame_names = sorted(os.listdir(path))[::step]
        images = [imageio.imread(os.path.join(path, frame))
                  for frame in frame_names]
        imageio.mimsave(os.path.join(ENC_PLOT_DATA_PATH, title),
                        images,
                        duration=0.1)

        shutil.rmtree(path)

    def split_data(data):
        """Split *data* into ``(x_train, x_test)`` with a fixed seed."""
        # TODO: try using 2 proteins: train on the first one test on the other one
        x_train, x_test = train_test_split(data,
                                           test_size=0.1,
                                           random_state=42)
        return x_train, x_test

    def add_noise(x, std):
        """Return *x* plus noise drawn from ``normal(0, std)`` of the same shape."""
        return x + normal(0, std, size=x.shape)

    # Sorted so the processing order (and thus the noise RNG draws per file)
    # is reproducible across platforms/filesystems.
    for file_name in sorted(os.listdir(file_path)):
        if not file_name.endswith(csv_suffix):
            continue
        # Strip only the trailing suffix; str.replace would also remove any
        # ".csv" occurring earlier in the name.
        name = file_name[:-len(csv_suffix)]
        print("Analysing %s" % name)

        # Read the file once and reuse it for both the split and the plot.
        data = read_data(os.path.join(file_path, file_name), normalize='std')
        x_train, x_test = split_data(data)
        # NOTE(review): the noisy splits are not used below.
        x_train_noisy = add_noise(x_train, std=noise_std)
        x_test_noisy = add_noise(x_test, std=noise_std)

        # NOTE(review): `Sequential` here takes input_size/encoded_size and
        # exposes `.plot`, so it is presumably a project class rather than
        # keras.Sequential -- confirm. The hard-coded input_shape=(5, 10) is
        # independent of x_train.shape[1]; verify they are meant to agree.
        network_model = Sequential(input_size=x_train.shape[1], encoded_size=2)
        network_model.add(
            Bidirectional(LSTM(10, return_sequences=True),
                          input_shape=(5, 10)))
        network_model.add(Bidirectional(LSTM(10)))
        network_model.add(Dense(5))
        network_model.add(Activation('softmax'))
        network_model.compile(loss='categorical_crossentropy',
                              optimizer='rmsprop')

        # encoded = AutoEncoder(load_model("encoding/models/proteins-autoencoder_6EQE_Angles_And_RSA.h5")).encode(data)
        # df = pd.DataFrame(encoded)
        # df.to_csv("encoding/models/6EQE_Angles_And_RSA_result.csv", index=False, header=False)

        network_model.plot(data,
                           title=name,
                           file=os.path.join(ENC_PLOT_DATA_PATH, name + "_" +
                                             str(epochs) + ".jpg"))