def test_serialize_deserialize_with_batching(self):
    from steves_utils.ORACLE.simple_oracle_dataset_factory import Simple_ORACLE_Dataset_Factory
    from steves_utils.ORACLE.utils import ORIGINAL_PAPER_SAMPLES_PER_CHUNK, ALL_SERIAL_NUMBERS

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=[8],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:3]
    )

    for e in ds.batch(1000):
        record       = example_to_tf_record(e)
        serialized   = record.SerializeToString()
        deserialized = serialized_tf_record_to_example(serialized)

        self.assertTrue(
            np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy())
        )

        # Just a dumb sanity check: a single element of the batch should never
        # compare equal to the whole batch.
        self.assertFalse(
            np.array_equal(e["IQ"].numpy(), deserialized["IQ"].numpy()[0])
        )
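# A minimal round-trip sketch, separate from the test above: writing the serialized
# batches to a TFRecord file on disk and reading them back. It assumes tensorflow is
# imported as tf in this module; tf.io.TFRecordWriter and tf.data.TFRecordDataset are
# standard TensorFlow APIs, example_to_tf_record and serialized_tf_record_to_example
# are the module-under-test helpers exercised above, and the output path is a placeholder.
def _sketch_tfrecord_round_trip(ds, path="/tmp/oracle_sketch.tfrecord_ds"):
    with tf.io.TFRecordWriter(path) as writer:
        for e in ds.batch(1000):
            writer.write(example_to_tf_record(e).SerializeToString())

    # Each element of the file is one serialized batch.
    for serialized in tf.data.TFRecordDataset([path]):
        batch = serialized_tf_record_to_example(serialized.numpy())
        print(batch["IQ"].shape)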
def get_limited_oracle():
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 1000
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=[8],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:3]
    )

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024))

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val   = int(cardinality * VAL_SPLIT)
    num_test  = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)

    train_ds = ds.take(num_train)
    val_ds   = ds.skip(num_train).take(num_val)
    test_ds  = ds.skip(num_train + num_val).take(num_test)

    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.batch(BATCH)
    val_ds   = val_ds.batch(BATCH)
    test_ds  = test_ds.batch(BATCH)

    return train_ds, val_ds, test_ds
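# Usage sketch for the splits returned above. The model is a throwaway placeholder and
# the (2, ORIGINAL_PAPER_SAMPLES_PER_CHUNK) input shape is an assumption about how the
# "IQ" tensors are laid out; only the fit/evaluate wiring is the point here.
def _sketch_train_limited_oracle():
    train_ds, val_ds, test_ds = get_limited_oracle()

    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(2, ORIGINAL_PAPER_SAMPLES_PER_CHUNK)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(len(ALL_SERIAL_NUMBERS), activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])

    model.fit(train_ds, validation_data=val_ds, epochs=3)
    return model.evaluate(test_ds)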
def setUpClass(self):
    self.pile_path   = os.path.join(SCRATCH_DIR, "piles")
    self.output_path = os.path.join(SCRATCH_DIR, "output")

    self.runs_to_get           = [1]
    self.distances_to_get      = ALL_DISTANCES_FEET[:3]
    self.serial_numbers_to_get = ALL_SERIAL_NUMBERS[:3]
    self.train_val_test_splits = (0.6, 0.2, 0.2)

    self.shuffler = Dataset_Shuffler(
        num_samples_per_chunk=ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        output_batch_size=1000,
        num_piles=5,
        output_format_str="shuffled_batchSize-{batch_size}_part-{part}.tfrecord_ds",
        output_max_file_size_MB=5,
        pile_dir=self.pile_path,
        output_dir=self.output_path,
        seed=1337,
        runs_to_get=self.runs_to_get,
        distances_to_get=self.distances_to_get,
        serial_numbers_to_get=self.serial_numbers_to_get,
    )

    clear_scrath_dir()
    self.shuffler.create_and_check_dirs()

    print("Write piles")
    self.shuffler.write_piles()

    print("shuffle")
    self.shuffler.shuffle_piles()

    self.simple_ds, self.cardinality = Simple_ORACLE_Dataset_Factory(
        num_samples_per_chunk=ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=self.runs_to_get,
        distances_to_get=self.distances_to_get,
        serial_numbers_to_get=self.serial_numbers_to_get,
    )
def get_windowed_less_limited_oracle():
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 500
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6]
    )

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * 2 * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024))
    # input("Press Enter to continue")

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val   = int(cardinality * VAL_SPLIT)
    num_test  = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(os.path.join(steves_utils.utils.get_datasets_base_path(), "caches", "windowed_less_limited_oracle"))

    # Prime the cache
    # for e in ds.batch(1000):
    #     pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    """
    stride, repeat, batch
    16, 9,   1000  good
    16, 9,   500   good
    32, 5,   1000  bad
    32, 5,   500   good
    4,  33,  500   bad
    4,  33,  500   with shuffling bad (but slightly better)
    1,  129, 500   with shuffling GOOD!
    """
    STRIDE_SIZE = 1
    NUM_REPEATS = 129

    train_ds = ds.take(num_train)
    val_ds   = ds.skip(num_train).take(num_val)
    test_ds  = ds.skip(num_train + num_val).take(num_test)

    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.map(
        lambda x, y: (
            tf.transpose(
                # Frame each double-length chunk into 1 + chunk // STRIDE_SIZE windows
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, STRIDE_SIZE),
                [1, 0, 2]
            ),
            # Repeat the one-hot label once per window so labels stay aligned
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=NUM_REPEATS, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, STRIDE_SIZE),
                [1, 0, 2]
            ),
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=NUM_REPEATS, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, STRIDE_SIZE),
                [1, 0, 2]
            ),
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=NUM_REPEATS, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    for e in train_ds.take(1):
        print(e)

    train_ds = train_ds.unbatch()
    val_ds   = val_ds.unbatch()
    test_ds  = test_ds.unbatch()

    train_ds = train_ds.shuffle(BATCH * NUM_REPEATS * 3)
    val_ds   = val_ds.shuffle(BATCH * NUM_REPEATS * 3)
    test_ds  = test_ds.shuffle(BATCH * NUM_REPEATS * 3)

    train_ds = train_ds.batch(BATCH)
    val_ds   = val_ds.batch(BATCH)
    test_ds  = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(10)
    val_ds   = val_ds.prefetch(10)
    test_ds  = test_ds.prefetch(10)

    return train_ds, val_ds, test_ds
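# Standalone worked check of the frame/repeat bookkeeping above (assumes the per-example
# IQ layout is (2, num_samples), which is what the [1, 0, 2] transpose implies).
# tf.signal.frame with the default pad_end=False produces 1 + (N - frame_length) // frame_step
# frames, so a double-length chunk framed at frame_length = chunk gives
# 1 + chunk // STRIDE_SIZE windows: 129 for a 128-sample chunk at stride 1, 9 at stride 16.
# NUM_REPEATS must match that count so every window keeps its one-hot label.
def _sketch_frame_count(chunk=ORIGINAL_PAPER_SAMPLES_PER_CHUNK, stride=1):
    fake_iq = tf.zeros((2, 2 * chunk))                # one double-length example
    frames = tf.signal.frame(fake_iq, chunk, stride)  # -> (2, num_frames, chunk)
    num_frames = frames.shape[1]
    assert num_frames == 1 + (2 * chunk - chunk) // stride
    return num_frames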
def get_less_limited_oracle():
    """test loss: 0.05208379030227661, test acc: 0.16599488258361816"""
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 1000
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6]
    )

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024))
    input("Press Enter to continue")

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val   = int(cardinality * VAL_SPLIT)
    num_test  = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(os.path.join(steves_utils.utils.get_datasets_base_path(), "caches", "less_limited_oracle"))

    # Prime the cache: one full pass writes the shuffled order into the cache file,
    # so the take/skip splits below always read the same frozen order.
    for e in ds.batch(1000):
        pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds   = ds.skip(num_train).take(num_val)
    test_ds  = ds.skip(num_train + num_val).take(num_test)

    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.batch(BATCH)
    val_ds   = val_ds.batch(BATCH)
    test_ds  = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds   = val_ds.prefetch(100)
    test_ds  = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
def get_windowed_less_limited_oracle():
    """test loss: 0.018825260922312737, test acc: 0.7720255255699158"""
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 1000
    RANGE = len(ALL_SERIAL_NUMBERS)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6]
    )

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * 2 * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8 / 1024 / 1024 / 1024))
    input("Press Enter to continue")

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val   = int(cardinality * VAL_SPLIT)
    num_test  = int(cardinality * TEST_SPLIT)

    ds = ds.shuffle(cardinality)
    ds = ds.cache(os.path.join(steves_utils.utils.get_datasets_base_path(), "caches", "windowed_less_limited_oracle"))

    # Prime the cache
    # for e in ds.batch(1000):
    #     pass
    # print("Buffer primed. Comment this out next time")
    # sys.exit(1)

    train_ds = ds.take(num_train)
    val_ds   = ds.skip(num_train).take(num_val)
    test_ds  = ds.skip(num_train + num_val).take(num_test)

    train_ds = train_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x: (x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.map(
        lambda x, y: (
            tf.transpose(
                # Framing a double-length chunk at frame_step 16 yields
                # 1 + chunk // 16 windows (9 for a 128-sample chunk)
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, 16),
                [1, 0, 2]
            ),
            # Repeat the one-hot label once per window (9 times)
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=9, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    val_ds = val_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, 16),
                [1, 0, 2]
            ),
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=9, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )
    test_ds = test_ds.map(
        lambda x, y: (
            tf.transpose(
                tf.signal.frame(x, ORIGINAL_PAPER_SAMPLES_PER_CHUNK, 16),
                [1, 0, 2]
            ),
            tf.repeat(tf.reshape(y, (1, RANGE)), repeats=9, axis=0)
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True
    )

    train_ds = train_ds.unbatch()
    val_ds   = val_ds.unbatch()
    test_ds  = test_ds.unbatch()

    train_ds = train_ds.batch(BATCH)
    val_ds   = val_ds.batch(BATCH)
    test_ds  = test_ds.batch(BATCH)

    train_ds = train_ds.prefetch(100)
    val_ds   = val_ds.prefetch(100)
    test_ds  = test_ds.prefetch(100)

    return train_ds, val_ds, test_ds
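# Sanity-check sketch for the windowed pipeline above (assumes the "IQ" tensors are laid
# out as (2, num_samples), which is what the [1, 0, 2] transpose implies): after
# windowing, unbatching, and re-batching, every example should look exactly like a
# single original-length chunk paired with a one-hot label over all serial numbers.
def _sketch_check_windowed_shapes():
    train_ds, _, _ = get_windowed_less_limited_oracle()
    for x, y in train_ds.unbatch().take(1):
        assert x.shape == (2, ORIGINAL_PAPER_SAMPLES_PER_CHUNK)
        assert y.shape == (len(ALL_SERIAL_NUMBERS),)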
def __init__(
    self,
    output_batch_size,
    # num_piles,
    output_max_file_size_MB,
    pile_dir,
    output_dir,
    seed,
    num_samples_per_chunk,
    distances_to_get: List[int] = ALL_DISTANCES_FEET,
    serial_numbers_to_get: List[str] = ALL_SERIAL_NUMBERS,
    runs_to_get: List[int] = ALL_RUNS,
    output_format_str="shuffled_batchSize-{batch_size}_part-{part}.tfrecord_ds",
    fail_on_too_few_output_parts=True,
) -> None:
    self.output_batch_size = output_batch_size
    # self.num_piles = num_piles
    self.output_max_file_size_MB = output_max_file_size_MB
    self.pile_dir = pile_dir
    self.output_dir = output_dir
    self.seed = seed
    self.num_samples_per_chunk = num_samples_per_chunk
    self.distances_to_get = distances_to_get
    self.serial_numbers_to_get = serial_numbers_to_get
    self.runs_to_get = runs_to_get
    self.output_format_str = output_format_str

    self.ds, self.cardinality = Simple_ORACLE_Dataset_Factory(
        num_samples_per_chunk,
        runs_to_get=runs_to_get,
        distances_to_get=distances_to_get,
        serial_numbers_to_get=serial_numbers_to_get
    )

    # Each chunk is num_samples_per_chunk samples stored as two 8-byte values (I and Q).
    self.total_ds_size_GB = self.cardinality * self.num_samples_per_chunk * 8 * 2 / 1024 / 1024 / 1024
    self.num_piles = int(math.ceil(self.total_ds_size_GB))

    # self.expected_pile_size_GB = self.total_ds_size_GB / self.num_piles
    # if self.expected_pile_size_GB > 5:
    #     raise Exception("Expected pile size is too big: {}GB. Increase your num_piles".format(self.expected_pile_size_GB))

    self.expected_num_parts = self.total_ds_size_GB * 1024 / output_max_file_size_MB
    if self.expected_num_parts < 15:
        if fail_on_too_few_output_parts:
            raise Exception(
                "Expected number of output parts is {}, need a minimum of 15".format(self.expected_num_parts))
        else:
            print(
                "Expected number of output parts is {}, need a minimum of 15".format(self.expected_num_parts))

    self.shuffler = steves_utils.dataset_shuffler.Dataset_Shuffler(
        input_ds=self.ds,
        one_example_to_tf_record_func=oracle_serialization.example_to_tf_record,
        one_example_from_serialized_tf_record_func=oracle_serialization.serialized_tf_record_to_example,
        batch_example_to_tf_record_func=oracle_serialization.example_to_tf_record,
        output_batch_size=output_batch_size,
        num_piles=self.num_piles,
        output_format_str=output_format_str,
        output_max_file_size_MB=output_max_file_size_MB,
        pile_dir=pile_dir,
        output_dir=output_dir,
        seed=seed
    )
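# Worked example of the sizing arithmetic above, with made-up numbers (the real
# cardinality depends on which runs/distances/serials are requested): 200_000 chunks
# of 128 samples at 16 bytes per sample is roughly 0.38 GB, which ceil()s to a single
# pile and, with a 5 MB output file cap, roughly 78 output parts.
def _sketch_shuffler_sizing(cardinality=200_000, num_samples_per_chunk=128, output_max_file_size_MB=5):
    total_ds_size_GB = cardinality * num_samples_per_chunk * 8 * 2 / 1024 / 1024 / 1024
    num_piles = int(math.ceil(total_ds_size_GB))
    expected_num_parts = total_ds_size_GB * 1024 / output_max_file_size_MB
    return total_ds_size_GB, num_piles, expected_num_parts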