Example #1
def fit(self, dataset):
    """Fit the preprocessors."""
    x = dataset.map(lambda x, y: x)
    sources_x = data_utils.unzip_dataset(x)
    for pps_list, data in zip(self.inputs, sources_x):
        for preprocessor in pps_list:
            preprocessor.fit(data)
            # Transform so the next preprocessor is fit on transformed data.
            data = preprocessor.transform(data)
    y = dataset.map(lambda x, y: y)
    sources_y = data_utils.unzip_dataset(y)
    for pps_list, data in zip(self.outputs, sources_y):
        for preprocessor in pps_list:
            preprocessor.fit(data)
            data = preprocessor.transform(data)
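
This fit expects a zipped tf.data.Dataset of (x, y) pairs, since it maps the inputs and targets out separately. A minimal sketch of such an input; the pipeline instance itself is hypothetical:

import numpy as np
import tensorflow as tf

x = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 32))
y = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 1))
dataset = tf.data.Dataset.zip((x, y))  # elements are (x, y) pairs
# pipeline.fit(dataset)  # `pipeline` is a hypothetical instance of the class above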
Example #2
def _transform_data(self, dataset, pps_lists):
    sources = data_utils.unzip_dataset(dataset)
    transformed = []
    for pps_list, data in zip(pps_lists, sources):
        # Apply each source's preprocessors in sequence.
        for preprocessor in pps_list:
            data = preprocessor.transform(data)
        transformed.append(data)
    # Return a bare dataset for a single source, a tuple otherwise.
    if len(transformed) == 1:
        return transformed[0]
    return tuple(transformed)
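
The loop above only requires that each preprocessor expose a transform method taking and returning a dataset. A minimal hypothetical preprocessor satisfying that interface (the scaling is invented for illustration):

import tensorflow as tf

class DivideBy255:  # hypothetical preprocessor
    def transform(self, dataset):
        return dataset.map(lambda x: x / 255.0)

ds = tf.data.Dataset.from_tensor_slices([[255.0, 0.0]])
ds = DivideBy255().transform(ds)  # the same call the inner loop makes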
Example #3
def _build_preprocessors(hp, hpps_lists, dataset):
    sources = data_utils.unzip_dataset(dataset)
    preprocessors_list = []
    for source, hpps_list in zip(sources, hpps_lists):
        data = source
        preprocessors = []
        for hyper_preprocessor in hpps_list:
            # Build a concrete preprocessor from the hyperparameters, then
            # transform the data so the next one is built on transformed data.
            preprocessor = hyper_preprocessor.build(hp, data)
            data = preprocessor.transform(data)
            preprocessors.append(preprocessor)
        preprocessors_list.append(preprocessors)
    return preprocessors_list
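
A hypothetical hyper-preprocessor illustrating the build-then-transform chaining above; both classes are invented, and the hp object is assumed to be a keras_tuner.HyperParameters instance as AutoKeras passes around:

import keras_tuner
import tensorflow as tf

class CastToFloat32:  # hypothetical preprocessor
    def transform(self, dataset):
        return dataset.map(lambda x: tf.cast(x, tf.float32))

class HyperCast:  # hypothetical hyper-preprocessor
    def build(self, hp, dataset):
        # A real build could read hp values here, e.g. hp.Boolean("enabled").
        return CastToFloat32()

hp = keras_tuner.HyperParameters()
data = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
preprocessor = HyperCast().build(hp, data)  # same call as the loop above
data = preprocessor.transform(data)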
Example #4
def _adapt(self, dataset, hms, batch_size):
    # Assumes module-level `import tensorflow as tf` and
    # `from tensorflow import nest`.
    if isinstance(dataset, tf.data.Dataset):
        sources = data_utils.unzip_dataset(dataset)
    else:
        sources = nest.flatten(dataset)
    adapted = []
    for source, hm in zip(sources, hms):
        source = hm.get_adapter().adapt(source, batch_size)
        adapted.append(source)
    if len(adapted) == 1:
        return adapted[0]
    # Re-zip the per-source datasets into a single dataset.
    return tf.data.Dataset.zip(tuple(adapted))
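
The final tf.data.Dataset.zip mirrors the unzip at the top of the method. A small standalone sketch of that re-zipping (shapes are illustrative):

import numpy as np
import tensorflow as tf

x = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 32))
y = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 1))
zipped = tf.data.Dataset.zip((x, y))  # same call as the final return
for x_batch, y_batch in zipped.batch(4).take(1):
    print(x_batch.shape, y_batch.shape)  # (4, 32) (4, 1)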
Example #5
def test_unzip_dataset_doesnt_unzip_single_dataset():
    # Assumes: import numpy as np; import tensorflow as tf;
    # from autokeras.utils import data_utils (import path assumed).
    dataset = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 32, 2))
    # Unzipping a non-zipped dataset is a no-op, so the element shape
    # survives repeated calls.
    dataset = data_utils.unzip_dataset(dataset)[0]
    dataset = data_utils.unzip_dataset(dataset)[0]
    assert data_utils.dataset_shape(dataset).as_list() == [32, 2]
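
By contrast, a dataset that really is zipped does get split into its component datasets. A short sketch under the same assumptions as the test above (the data_utils import path is assumed):

import numpy as np
import tensorflow as tf
from autokeras.utils import data_utils  # import path assumed

a = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 32, 2))
b = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 4))
parts = data_utils.unzip_dataset(tf.data.Dataset.zip((a, b)))
assert data_utils.dataset_shape(parts[0]).as_list() == [32, 2]
assert data_utils.dataset_shape(parts[1]).as_list() == [4]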