def test_input_fn(self):
    """Check that ``census_dataset.input_fn`` parses ``self.input_csv`` as expected.

    Builds a single-epoch, unshuffled dataset with batch size 1 and pulls the
    first batch through a TF1-style one-shot iterator. For every key in
    ``TEST_INPUT_VALUES`` it then asserts that:
      * the key is present in the parsed features,
      * the feature column holds exactly one value (batch size 1),
      * that value (decoded from bytes when necessary) equals the expected one.
    Finally it asserts that the parsed label is falsy.
    """
    dataset = census_dataset.input_fn(self.input_csv, 1, False, 1)
    features, labels = dataset.make_one_shot_iterator().get_next()

    with self.test_session() as sess:
        features, labels = sess.run((features, labels))

        # Compare the two features dictionaries.
        for key in TEST_INPUT_VALUES:
            # assertIn reports the missing key and the available keys on failure,
            # unlike assertTrue(key in features).
            self.assertIn(key, features)
            self.assertEqual(len(features[key]), 1)

            feature_value = features[key][0]

            # Convert from bytes to string for Python 3.
            if isinstance(feature_value, bytes):
                feature_value = feature_value.decode()

            # msg=key names the offending column when values differ.
            self.assertEqual(TEST_INPUT_VALUES[key], feature_value, msg=key)

        # NOTE(review): presumably the test row's label is the negative class,
        # so the batch-of-one label evaluates falsy — confirm against the CSV.
        self.assertFalse(labels)
def eval_input_fn():
    """Return the evaluation dataset: one unshuffled pass over ``test_file``.

    Batch size comes from ``flags_obj.batch_size``.
    """
    return census_dataset.input_fn(
        test_file,
        num_epochs=1,
        shuffle=False,
        batch_size=flags_obj.batch_size,
    )
def train_input_fn():
    """Return the shuffled training dataset built from ``train_file``.

    Epoch count comes from ``flags_obj.epochs_between_evals``; batch size
    comes from ``flags_obj.batch_size``.
    """
    return census_dataset.input_fn(
        train_file,
        num_epochs=flags_obj.epochs_between_evals,
        shuffle=True,
        batch_size=flags_obj.batch_size,
    )
def train_input_fn():
    """Build the training input pipeline over ``train_file``.

    Rows are shuffled. The number of epochs per call is
    ``flags_obj.epochs_between_evals`` and each batch has
    ``flags_obj.batch_size`` rows.
    """
    dataset = census_dataset.input_fn(
        train_file,
        num_epochs=flags_obj.epochs_between_evals,
        shuffle=True,
        batch_size=flags_obj.batch_size)
    return dataset
def input_fn():
    """Build a dataset over ``TEST_CSV``.

    ``num_epochs``, ``shuffle`` and ``batch_size`` are read from the
    enclosing scope at call time.
    """
    # Positional order matches census_dataset.input_fn's
    # (data, num_epochs, shuffle, batch_size) signature.
    return census_dataset.input_fn(TEST_CSV, num_epochs, shuffle, batch_size)
def easy_input_function(df, label_key, num_epochs, shuffle, batch_size,
                        shuffle_buffer_size=10000):
    """Build a ``tf.data.Dataset`` of ``(features, label)`` batches from a table.

    Args:
      df: Column-keyed table; ``dict(df)`` must map column names to
        equal-length columns. Presumably a pandas DataFrame (it is indexed
        by ``df[label_key]`` in callers) — confirm at call sites.
      label_key: Name of the column that holds the label.
      num_epochs: Number of times the batched data is repeated.
      shuffle: If true, shuffle rows before batching.
      batch_size: Number of rows per batch.
      shuffle_buffer_size: Shuffle buffer size, used only when ``shuffle`` is
        true. Defaults to 10000, the value previously hard-coded here.

    Returns:
      A dataset yielding ``(features_dict, label_batch)`` pairs. The label
      column is removed from ``features_dict``.

    Raises:
      KeyError: If ``label_key`` is not a column of ``df``.
    """
    features = dict(df)
    # Pop the label out of the feature dict. Leaving it in hands the target
    # to any model that consumes every feature key (label leakage).
    label = features.pop(label_key)
    ds = tf.data.Dataset.from_tensor_slices((features, label))

    if shuffle:
        ds = ds.shuffle(shuffle_buffer_size)

    # Batch before repeating so no batch straddles an epoch boundary.
    ds = ds.batch(batch_size).repeat(num_epochs)

    return ds


# Show the library input function's implementation for comparison.
import inspect
print(inspect.getsource(census_dataset.input_fn))

ds = census_dataset.input_fn(train_file, num_epochs=5, shuffle=True, batch_size=10)

for feature_batch, label_batch in ds.take(1):
    print('Feature keys:', list(feature_batch.keys())[:5])
    print()
    print('Age batch :', feature_batch['age'])
    print()
    print('Label batch :', label_batch)

# Zero-argument input functions with the dataset settings bound in.
import functools
train_inpf = functools.partial(census_dataset.input_fn, train_file, num_epochs=2, shuffle=True, batch_size=64)
test_inpf = functools.partial(census_dataset.input_fn, test_file, num_epochs=1, shuffle=False, batch_size=64)
def print_ds(ds):
    """Print a summary of the first ``(features, labels)`` element of ``ds``.

    The summary covers the feature keys, the ``"age"`` feature and the
    labels. Nothing is printed if ``ds`` yields no elements.
    """
    for features, labels in ds.take(1):
        print("some feature keys: ", list(features))
        print()
        print("A batch of Ages: ", features["age"])
        print()
        print("A batch of labels: ", labels)


# Same first-batch summary for the hand-built pipeline and the library one.
print_ds(easy_input_function(train_df, label_key='income_bracket', num_epochs=5, shuffle=True, batch_size=10))
print_ds(census_dataset.input_fn(train_file, 5, True, 10))

# using input functions
print("################# input functions")
import functools

train_inpf = functools.partial(census_dataset.input_fn, train_file, 5, True, 64)
test_inpf = functools.partial(census_dataset.input_fn, test_file, 1, False, 64)

# feature columns
print("################# feature columns")
age = tf.feature_column.numeric_column("age")

ds = train_inpf()
# Element 0 of the first (features, labels) pair is the feature dict.
feature_batch = next(iter(ds.take(1)))[0]
print(tf.feature_column.input_layer(feature_batch, [age]).numpy())