Example #1
    # Assumes module-level imports of numpy as np and that the serialization
    # helpers under test (example_to_tf_record, serialized_tf_record_to_example)
    # are in scope.
    def test_serialize_deserialize_with_batching(self):
        from steves_utils.ORACLE.simple_oracle_dataset_factory import Simple_ORACLE_Dataset_Factory
        from steves_utils.ORACLE.utils import ORIGINAL_PAPER_SAMPLES_PER_CHUNK, ALL_SERIAL_NUMBERS

        ds, cardinality = Simple_ORACLE_Dataset_Factory(
            ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
            runs_to_get=[1],
            distances_to_get=[8],
            serial_numbers_to_get=ALL_SERIAL_NUMBERS[:3])

        for e in ds.batch(1000):
            record = example_to_tf_record(e)
            serialized = record.SerializeToString()
            deserialized = serialized_tf_record_to_example(serialized)

            self.assertTrue(
                np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy()))

            # Sanity check that array_equal is not trivially True: the full batch
            # should not compare equal to just its first element.
            self.assertFalse(
                np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy()[0]))
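
The serialization helpers exercised above are not shown in this excerpt. A minimal sketch of what they might look like, assuming the record only needs to round-trip the "IQ" tensor (the feature names and layout here are illustrative, not the project's actual implementation):

import numpy as np
import tensorflow as tf

def example_to_tf_record(example):
    # Flatten the IQ tensor into a float list and keep its shape so it can be restored.
    iq = example["IQ"].numpy()
    return tf.train.Example(features=tf.train.Features(feature={
        "IQ":       tf.train.Feature(float_list=tf.train.FloatList(value=iq.flatten())),
        "IQ_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=iq.shape)),
    }))

def serialized_tf_record_to_example(serialized):
    # Parse the protobuf bytes back into a tf.train.Example and rebuild the tensor.
    record = tf.train.Example()
    record.ParseFromString(serialized)
    shape = record.features.feature["IQ_shape"].int64_list.value
    flat  = np.array(record.features.feature["IQ"].float_list.value, dtype=np.float32)
    return {"IQ": tf.constant(flat.reshape(shape))}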
Example #2
import os

import tensorflow as tf

import steves_utils.utils
from steves_utils.ORACLE.simple_oracle_dataset_factory import Simple_ORACLE_Dataset_Factory
# ALL_DISTANCES_FEET is assumed to live alongside the other ORACLE constants.
from steves_utils.ORACLE.utils import (
    ALL_DISTANCES_FEET,
    ALL_SERIAL_NUMBERS,
    ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
)


def get_less_limited_oracle():
    """test loss: 0.05208379030227661 , test acc: 0.16599488258361816"""
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 1000
    RANGE = len(ALL_SERIAL_NUMBERS)  # one-hot depth: number of serial-number classes

    
    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK, 
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6]
    )

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format( cardinality * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024))
    input("Pres Enter to continue")
    num_train = int(cardinality * TRAIN_SPLIT)
    num_val = int(cardinality * VAL_SPLIT)
    num_test = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(os.path.join(steves_utils.utils.get_datasets_base_path(), "caches", "less_limited_oracle"))

    # Prime the cache: one full pass writes every (shuffled) element to the cache
    # file, so the take/skip splits below read from a fixed cached order.
    for _ in ds.batch(1000):
        pass

    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds = ds.skip(num_train).take(num_val)
    test_ds = ds.skip(num_train+num_val).take(num_test)

    # Map each record to an (IQ, one-hot serial-number id) pair.
    def to_xy(x):
        return (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE))

    train_ds = train_ds.map(to_xy, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
    val_ds   = val_ds.map(to_xy, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
    test_ds  = test_ds.map(to_xy, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
    
    train_ds = train_ds.batch(BATCH)
    val_ds   = val_ds.batch(BATCH)
    test_ds  = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds   = val_ds.prefetch(100)
    test_ds  = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
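
A hedged usage sketch: the model architecture below is hypothetical, chosen only to show how the returned datasets plug into training; the function itself only guarantees batched (IQ, one-hot serial-number) pairs.

train_ds, val_ds, test_ds = get_less_limited_oracle()

# Hypothetical placeholder model; any Keras model that accepts the
# (IQ, one-hot label) pairs produced above would work here.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(len(ALL_SERIAL_NUMBERS), activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=10)
print(model.evaluate(test_ds))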