def get_windowed_less_limited_oracle():
    """Build batched train/val/test datasets of windowed ORACLE examples.

    Chunks of ``4 * ORIGINAL_PAPER_SAMPLES_PER_CHUNK`` samples are requested
    from ``Simple_ORACLE_Dataset_Factory`` (run 1, the first distance and the
    first six serial numbers), shuffled, cached on disk and split 60/20/20.
    Each chunk is then cut into frames of ``ORIGINAL_PAPER_SAMPLES_PER_CHUNK``
    samples along its last axis:

    * train: overlapping frames with a stride of ``STRIDE_SIZE`` (1) sample;
    * val/test: back-to-back, non-overlapping frames.

    Every frame becomes its own example and carries the one-hot label of the
    chunk it was cut from. The one-hot depth is ``len(ALL_SERIAL_NUMBERS)``,
    not the number of serial numbers actually loaded.

    Returns:
        A ``(train_ds, val_ds, test_ds)`` tuple of ``tf.data`` datasets that
        yield ``(frames, one_hot_labels)`` batches of up to 500 examples.
    """
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = (0.6, 0.2, 0.2)
    BATCH = 500
    RANGE = len(ALL_SERIAL_NUMBERS)

    frame_length = ORIGINAL_PAPER_SAMPLES_PER_CHUNK
    chunk_size = 4 * frame_length
    STRIDE_SIZE = 1

    def frames_per_chunk(frame_step):
        # Frame count produced by tf.signal.frame without end padding:
        # 1 + (N - frame_length) // frame_step.
        return (chunk_size - frame_length) // frame_step + 1

    # Frames cut from one training chunk (3 * frame_length + 1 for stride 1).
    NUM_REPEATS = frames_per_chunk(STRIDE_SIZE)

    ds, cardinality = Simple_ORACLE_Dataset_Factory(
        chunk_size,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6])

    print("Total Examples:", cardinality)
    print("That's {}GB of data (at least)".format(
        cardinality * chunk_size * 2 * 8 / 1024 / 1024 / 1024))

    num_train = int(cardinality * TRAIN_SPLIT)
    num_val = int(cardinality * VAL_SPLIT)
    num_test = int(cardinality * TEST_SPLIT)

    # Shuffle *before* caching so the take/skip splits below are cut from one
    # stored order instead of a fresh shuffle per iteration.
    # NOTE(review): the cache was originally primed with one full pass (e.g.
    # `for _ in ds.batch(1000): pass`) before first real use; the splits only
    # partially read `ds`, so confirm the cache file exists before training.
    ds = ds.shuffle(cardinality)
    ds = ds.cache(
        os.path.join(steves_utils.utils.get_datasets_base_path(), "caches",
                     "windowed_less_limited_oracle"))

    def to_iq_and_label(x):
        # Keep only the IQ tensor and turn the serial-number id into a one-hot.
        return x["IQ"], tf.one_hot(x["serial_number_id"], RANGE)

    def make_windower(frame_step):
        # Returns a map function that cuts one chunk into frames and repeats
        # the chunk's label once per frame.
        num_frames = frames_per_chunk(frame_step)

        def window(x, y):
            # Framing inserts a frame axis before the last axis; the transpose
            # moves it to the front so unbatch() can split on it.
            # Assumes x is rank 2 with chunk_size samples on its last axis
            # (presumably (2, chunk_size) for I and Q) -- TODO confirm.
            frames = tf.transpose(
                tf.signal.frame(x, frame_length, frame_step), [1, 0, 2])
            labels = tf.repeat(
                tf.reshape(y, (1, RANGE)), repeats=num_frames, axis=0)
            return frames, labels

        return window

    def build_split(split_ds, frame_step):
        # Shared tail of the pipeline: label, window, flatten frames into
        # individual examples, shuffle, batch, prefetch.
        split_ds = split_ds.map(
            to_iq_and_label,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=True)
        split_ds = split_ds.map(
            make_windower(frame_step),
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=True)
        split_ds = split_ds.unbatch()
        # The buffer size is based on the training frame count for all splits.
        split_ds = split_ds.shuffle(BATCH * NUM_REPEATS * 3)
        split_ds = split_ds.batch(BATCH)
        return split_ds.prefetch(100)

    # Training uses heavily overlapping windows; val/test use non-overlapping
    # ones (chunk_size // frame_length == 4 frames per chunk).
    train_ds = build_split(ds.take(num_train), STRIDE_SIZE)
    val_ds = build_split(ds.skip(num_train).take(num_val), frame_length)
    test_ds = build_split(
        ds.skip(num_train + num_val).take(num_test), frame_length)

    return train_ds, val_ds, test_ds
def get_less_limited_oracle():
    """Build batched train/val/test datasets of un-windowed ORACLE chunks.

    Requests chunks of ``ORIGINAL_PAPER_SAMPLES_PER_CHUNK`` samples (run 1,
    the first distance, the first six serial numbers), prints the dataset
    size and waits for the user to press Enter. It then shuffles, caches on
    disk and splits 60/20/20. Each example is mapped to
    ``(IQ, one_hot(serial_number_id))`` with a one-hot depth of
    ``len(ALL_SERIAL_NUMBERS)``.

    Recorded result from the original author:
    test loss: 0.05208379030227661 , test acc: 0.16599488258361816

    Returns:
        A ``(train_ds, val_ds, test_ds)`` tuple of ``tf.data`` datasets
        batched to 500 examples and prefetched.
    """
    train_frac, val_frac, test_frac = (0.6, 0.2, 0.2)
    batch_size = 500
    num_classes = len(ALL_SERIAL_NUMBERS)

    dataset, total_examples = Simple_ORACLE_Dataset_Factory(
        ORIGINAL_PAPER_SAMPLES_PER_CHUNK,
        runs_to_get=[1],
        distances_to_get=ALL_DISTANCES_FEET[:1],
        serial_numbers_to_get=ALL_SERIAL_NUMBERS[:6],
    )

    print("Total Examples:", total_examples)
    approx_size = (total_examples * ORIGINAL_PAPER_SAMPLES_PER_CHUNK * 2 * 8
                   / 1024 / 1024 / 1024)
    print("That's {}GB of data (at least)".format(approx_size))
    # Deliberate pause so the size above can be checked before loading.
    input("Pres Enter to continue")

    train_count = int(total_examples * train_frac)
    val_count = int(total_examples * val_frac)
    test_count = int(total_examples * test_frac)

    # Shuffle first, then cache, so the splits are cut from one stored order.
    # NOTE(review): the original carried a commented-out "prime the cache"
    # full pass; confirm the cache file is complete before relying on it.
    dataset = dataset.shuffle(total_examples).cache(
        os.path.join(
            steves_utils.utils.get_datasets_base_path(),
            "caches",
            "less_limited_oracle",
        )
    )

    raw_splits = (
        dataset.take(train_count),
        dataset.skip(train_count).take(val_count),
        dataset.skip(train_count + val_count).take(test_count),
    )

    def as_supervised(example):
        # (features, label) pair expected by the training loop.
        return (example["IQ"],
                tf.one_hot(example["serial_number_id"], num_classes))

    train_ds, val_ds, test_ds = (
        split.map(
            as_supervised,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=True,
        ).batch(batch_size).prefetch(100)
        for split in raw_splits
    )

    return train_ds, val_ds, test_ds