Esempio n. 1
0
def loading_data(cacheFile):
    """Load the training arrays from an HDF5 cache, building the cache if absent.

    If ``cacheFile`` exists, every array is read from it. Otherwise the raw
    data set is read with ``read_raw_data_sets`` (configured from ``FLAGS``),
    pulled out as a single batch of ``ds._reads_n`` reads, and written to
    ``cacheFile`` so later calls can take the fast path.

    Args:
        cacheFile: Path of the HDF5 cache file.

    Returns:
        Tuple ``(X, seq_len, label, label_vec, label_seg)``. ``label`` is the
        3-element sequence ``[index, value, shape]`` stored under ``Y_ctc/``.
    """
    if os.path.exists(cacheFile):
        with h5py.File(cacheFile, "r") as hf:
            X = hf["X_data"][:]
            seq_len = hf["seq_len"][:]
            # Every dataset must be read into memory *inside* this ``with``
            # block: a bare h5py Dataset handle is unusable once the file is
            # closed. ``[()]`` reads the whole dataset whatever its rank.
            label = [
                hf["Y_ctc/index"][:],
                hf["Y_ctc/value"][:],
                hf["Y_ctc/shape"][()],
            ]
            label_vec = hf["Y_vec"][:]
            label_seg = hf["Y_seg"][:]
        return X, seq_len, label, label_vec, label_seg

    print("Now caching the data ... ")

    ds = read_raw_data_sets(FLAGS.data_dir, FLAGS.train_cache,
                            FLAGS.sequence_len, FLAGS.k_mer)
    X, seq_len, label, label_vec, label_seg = ds.next_batch(ds._reads_n)

    # Write to a temporary file and move it into place only after every
    # dataset has been written. A failed or interrupted run then never leaves
    # a partial cache behind that the ``os.path.exists`` check above would
    # accept on the next call.
    tmp_file = os.fspath(cacheFile) + ".tmp"
    try:
        with h5py.File(tmp_file, "w") as hf:
            hf.create_dataset("X_data", data=X)
            hf.create_dataset("seq_len", data=seq_len)

            hf.create_dataset("Y_vec", data=label_vec)
            hf.create_dataset("Y_seg", data=label_seg)

            hf.create_dataset("Y_ctc/index", data=label[0])
            hf.create_dataset("Y_ctc/value", data=label[1])
            hf.create_dataset("Y_ctc/shape", data=label[2])
        os.replace(tmp_file, cacheFile)
    except BaseException:
        # Clean up the half-written file, then propagate the original error.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

    print("Done!")

    return X, seq_len, label, label_vec, label_seg
Esempio n. 2
0
def _check_cached_event_length(ds, split_name):
    """Raise ValueError if ``ds.event``'s second dimension != ``FLAGS.sequence_len``."""
    event_len = ds.event.shape[1]
    if event_len != FLAGS.sequence_len:
        raise ValueError(
            "The event length of %s cached dataset %d is inconsistent with "
            "given sequence_len %d" % (split_name, event_len, FLAGS.sequence_len))


def generate_train_valid_datasets():
    """Build the training and validation data sets according to ``FLAGS``.

    With ``FLAGS.read_cache`` set, both sets are loaded with
    ``read_cache_dataset``. Their event length (``event.shape[1]``) is then
    checked against ``FLAGS.sequence_len``. Otherwise the training set is read
    with ``read_raw_data_sets`` and the validation set with ``read_tfrecord``.
    In either mode, when ``FLAGS.validation`` is None the training set is
    reused as the validation set.

    Returns:
        Tuple ``(train_ds, valid_ds)``.

    Raises:
        ValueError: if a cached data set's event length differs from
            ``FLAGS.sequence_len``.
    """
    if FLAGS.read_cache:
        train_ds = read_cache_dataset(FLAGS.train_cache)
        if FLAGS.validation is not None:
            valid_ds = read_cache_dataset(FLAGS.valid_cache)
        else:
            valid_ds = train_ds
        _check_cached_event_length(train_ds, "training")
        _check_cached_event_length(valid_ds, "validation")
        return train_ds, valid_ds

    sys.stdout.write("Begin reading training dataset.\n")

    train_ds = read_raw_data_sets(FLAGS.data_dir, FLAGS.train_cache,
                                  FLAGS.sequence_len, FLAGS.k_mer)

    sys.stdout.write("Begin reading validation dataset.\n")

    if FLAGS.validation is not None:
        valid_ds = read_tfrecord(FLAGS.data_dir,
                                 FLAGS.validation,
                                 FLAGS.valid_cache,
                                 FLAGS.sequence_len,
                                 k_mer=FLAGS.k_mer,
                                 max_segments_num=FLAGS.segments_num)
    else:
        valid_ds = train_ds

    return train_ds, valid_ds