def test_serialize_deserialize_with_batching(self):
    import numpy as np

    from steves_utils.ORACLE.simple_oracle_dataset_factory import Simple_ORACLE_Dataset_Factory
    from steves_utils.ORACLE.utils import ORIGINAL_PAPER_SAMPLES_PER_CHUNK, ALL_SERIAL_NUMBERS

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=[8],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:3]
    )

    for e in ds.batch(1000):
        # Round-trip each batch through the TFRecord serialization helpers
        record = example_to_tf_record(e)
        serialized = record.SerializeToString()
        deserialized = serialized_tf_record_to_example(serialized)

        self.assertTrue(
            np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy())
        )

        # Sanity check that array_equal is not trivially true: a single element
        # of the batch must not compare equal to the whole batch
        self.assertFalse(
            np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy()[0])
        )
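# example_to_tf_record and serialized_tf_record_to_example are not defined in this
# file. Below is a minimal sketch of what such a round-trip pair could look like,
# assuming the batch dict holds a float64 "IQ" tensor (8 bytes per value, matching
# the size arithmetic in get_less_limited_oracle below) and an int64
# "serial_number_id" tensor. The feature names, dtypes, and shape-carrying field
# are illustrative assumptions, not the actual wire format used by steves_utils.
def _sketch_example_to_tf_record(batch):
    import tensorflow as tf

    # Store the raw IQ bytes plus the tensor's shape so it can be rebuilt exactly
    iq = batch["IQ"].numpy()
    return tf.train.Example(features=tf.train.Features(feature={
        "IQ": tf.train.Feature(bytes_list=tf.train.BytesList(value=[iq.tobytes()])),
        "IQ_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=list(iq.shape))),
        "serial_number_id": tf.train.Feature(int64_list=tf.train.Int64List(
            value=batch["serial_number_id"].numpy().flatten().tolist())),
    }))

def _sketch_serialized_tf_record_to_example(serialized):
    import numpy as np
    import tensorflow as tf

    ex = tf.train.Example()
    ex.ParseFromString(serialized)

    # Rebuild the IQ tensor from its raw bytes and stored shape
    shape = list(ex.features.feature["IQ_shape"].int64_list.value)
    iq = np.frombuffer(
        ex.features.feature["IQ"].bytes_list.value[0], dtype=np.float64
    ).reshape(shape)

    return {
        "IQ": tf.constant(iq),
        "serial_number_id": tf.constant(
            ex.features.feature["serial_number_id"].int64_list.value),
    }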
def get_less_limited_oracle():
    """test loss: 0.05208379030227661 , test acc: 0.16599488258361816"""
    import os

    import tensorflow as tf

    import steves_utils.utils
    from steves_utils.ORACLE.simple_oracle_dataset_factory import Simple_ORACLE_Dataset_Factory
    from steves_utils.ORACLE.utils import (
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        ALL_SERIAL_NUMBERS,
        ALL_DISTANCES_FEET,  # assumed to live alongside the other ORACLE constants
    )

    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 1000

    # NOTE: RANGE spans every serial number, even though only the first 6 are
    # loaded below, so the one-hot labels are wider than the number of classes
    # actually present in the data.
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6]
    )

    print("Total Examples:", cardinality)
    # Each example is ORIGINAL_PAPER_SAMPLES_PER_CHUNK complex samples:
    # 2 floats (I and Q) at 8 bytes each
    print("That's {}GB of data (at least)".format(
        cardinality * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024
    ))
    input("Press Enter to continue")

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val = int(cardinality * VAL_SPLIT)
    num_test = int(cardinality * TEST_SPLIT)

    # Shuffle once, then cache to disk so the shuffled order is fixed and the
    # take/skip splits below stay disjoint across epochs
    ds = ds.shuffle(cardinality)
    ds = ds.cache(os.path.join(
        steves_utils.utils.get_datasets_base_path(), "caches", "less_limited_oracle"))

    # Prime the cache by iterating the whole dataset once
    for e in ds.batch(1000):
        pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds = ds.skip(num_train).take(num_val)
    test_ds = ds.skip(num_train + num_val).take(num_test)

    # Map each example to an (IQ, one-hot serial number) pair
    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.batch(BATCH)
    val_ds = val_ds.batch(BATCH)
    test_ds = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds = val_ds.prefetch(100)
    test_ds = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
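# A sketch of how a caller might consume get_less_limited_oracle(). The model
# architecture here is a placeholder, not the one that produced the metrics in
# the docstring above; only the (IQ, one-hot label) dataset contract comes from
# the function itself.
def _sketch_train_on_less_limited_oracle():
    import tensorflow as tf

    from steves_utils.ORACLE.utils import ALL_SERIAL_NUMBERS

    train_ds, val_ds, test_ds = get_less_limited_oracle()

    # Placeholder classifier; the output width matches the one-hot depth (RANGE)
    # used in get_less_limited_oracle
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(len(ALL_SERIAL_NUMBERS), activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",  # labels are one-hot encoded
        metrics=["accuracy"],
    )

    model.fit(train_ds, validation_data=val_ds, epochs=10)
    print(model.evaluate(test_ds))  # -> [test loss, test acc]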