Example #1
import math
import os

import tensorflow as tf

import steves_utils.utils
# Also required from the surrounding project (exact module paths are not
# shown in this snippet): Simple_ORACLE_Dataset_Factory, ALL_SERIAL_NUMBERS,
# ALL_DISTANCES_FEET, ORIGINAL_PAPER_SAMPLES_PER_CHUNK


def get_windowed_less_limited_oracle():
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 500
    RANGE = len(ALL_SERIAL_NUMBERS)

    chunk_size = 4 * ORIGINAL_PAPER_SAMPLES_PER_CHUNK
    STRIDE_SIZE = 1

    # Number of sliding windows of length ORIGINAL_PAPER_SAMPLES_PER_CHUNK
    # that fit in one chunk at this stride; this matches the frame count
    # tf.signal.frame produces in the train mapping below.
    NUM_REPEATS = math.floor(
        (chunk_size - ORIGINAL_PAPER_SAMPLES_PER_CHUNK) / STRIDE_SIZE) + 1

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        chunk_size,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6])

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(cardinality * chunk_size *
                                                  2 * 8 / 1024 / 1024 / 1024))
    # input("Press Enter to continue")
    num_train = int(cardinality * TRAIN_SPLIT)
    num_val = int(cardinality * VAL_SPLIT)
    num_test = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(
        os.path.join(steves_utils.utils.get_datasets_base_path(), "caches",
                     "windowed_less_limited_oracle"))

    # Prime the cache (run once, then comment out):
    # for e in ds.batch(1000):
    #     pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds = ds.skip(num_train).take(num_val)
    test_ds = ds.skip(num_train + num_val).take(num_test)

    # Map each record to (IQ samples, one-hot serial-number label)
    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    # Window each chunk: slide a length-ORIGINAL_PAPER_SAMPLES_PER_CHUNK
    # frame across the IQ tensor at STRIDE_SIZE, then move the frame axis
    # first so each element becomes NUM_REPEATS frames of shape
    # (2, ORIGINAL_PAPER_SAMPLES_PER_CHUNK).
    train_ds = train_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
                                STRIDE_SIZE),
                [1, 0, 2]),
            # Repeat the one-hot label once per frame
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=NUM_REPEATS,
                      axis=0)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    # Validation windows do not overlap (frame step == frame length), so
    # each chunk yields chunk_size // ORIGINAL_PAPER_SAMPLES_PER_CHUNK
    # frames.
    val_ds = val_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
                                ORIGINAL_PAPER_SAMPLES_PER_CHUNK),
                [1, 0, 2]),
            # Repeat the one-hot label once per frame
            tf.repeat(tf.reshape(y, (1, RANGE)),
                      repeats=math.floor(chunk_size /
                                         ORIGINAL_PAPER_SAMPLES_PER_CHUNK),
                      axis=0)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    # Test windows likewise do not overlap.
    test_ds = test_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
                                ORIGINAL_PAPER_SAMPLES_PER_CHUNK),
                [1, 0, 2]),
            # Repeat the one-hot label once per frame
            tf.repeat(tf.reshape(y, (1, RANGE)),
                      repeats=math.floor(chunk_size /
                                         ORIGINAL_PAPER_SAMPLES_PER_CHUNK),
                      axis=0)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    # Debugging aid: dump elements, then bail out
    # for e in train_ds:
    #     print(e)
    # for e in test_ds:
    #     print(e)
    # sys.exit(1)

    train_ds = train_ds.unbatch()
    val_ds = val_ds.unbatch()
    test_ds = test_ds.unbatch()

    train_ds = train_ds.shuffle(BATCH * NUM_REPEATS * 3)
    val_ds = val_ds.shuffle(BATCH * NUM_REPEATS * 3)
    test_ds = test_ds.shuffle(BATCH * NUM_REPEATS * 3)

    train_ds = train_ds.batch(BATCH)
    val_ds = val_ds.batch(BATCH)
    test_ds = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds = val_ds.prefetch(100)
    test_ds = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
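
A minimal usage sketch (not part of the original listing): the shapes noted
below assume each chunk's "IQ" tensor is laid out as (2, chunk_size), which
the [1, 0, 2] transpose above implies.

train_ds, val_ds, test_ds = get_windowed_less_limited_oracle()

for iq, labels in train_ds.take(1):
    # iq:     (BATCH, 2, ORIGINAL_PAPER_SAMPLES_PER_CHUNK)
    # labels: (BATCH, len(ALL_SERIAL_NUMBERS)) one-hot rows
    print(iq.shape, labels.shape)
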
Example #2
import os

import tensorflow as tf

import steves_utils.utils
# Also required from the surrounding project (exact module paths are not
# shown in this snippet): Simple_ORACLE_Dataset_Factory, ALL_SERIAL_NUMBERS,
# ALL_DISTANCES_FEET, ORIGINAL_PAPER_SAMPLES_PER_CHUNK


def get_less_limited_oracle():
    """test loss: 0.05208379030227661 , test acc: 0.16599488258361816"""
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 500
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6])

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 /
        1024))
    input("Press Enter to continue")
    num_train = int(cardinality * TRAIN_SPLIT)
    num_val = int(cardinality * VAL_SPLIT)
    num_test = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(
        os.path.join(steves_utils.utils.get_datasets_base_path(), "caches",
                     "less_limited_oracle"))

    # Prime the cache (run once, then comment out):
    # for e in ds.batch(1000):
    #     pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds = ds.skip(num_train).take(num_val)
    test_ds = ds.skip(num_train + num_val).take(num_test)

    # Map each record to (IQ samples, one-hot serial-number label)
    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)

    train_ds = train_ds.batch(BATCH)
    val_ds = val_ds.batch(BATCH)
    test_ds = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds = val_ds.prefetch(100)
    test_ds = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
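
As a sketch only, with a hypothetical stand-in classifier rather than the
project's actual network: the returned splits plug straight into Keras, and
evaluate() produces the kind of [test loss, test acc] pair recorded in the
docstring above.

train_ds, val_ds, test_ds = get_less_limited_oracle()

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(len(ALL_SERIAL_NUMBERS), activation="softmax"),
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",  # labels are one-hot
              metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=5)  # epoch count arbitrary
print(model.evaluate(test_ds))  # -> [test loss, test accuracy]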