Example #1
File: models.py Project: jackd/kblocks
def as_infinite_iterator(
        dataset: tf.data.Dataset,
        steps_per_epoch: Optional[int] = None) -> Tuple[tf.data.Iterator, int]:
    """
    Get an iterator for an infinite dataset and steps_per_epoch.

    Args:
        dataset: possibly infinite dataset.
        steps_per_epoch: number of steps per epoch if `dataset` has infinite
            cardinality, otherwise `None` or `dataset`'s cardinality.

    Returns:
        iterator: tf.data.Iterator of possibly repeated `dataset`.
        steps_per_epoch: number of elements in iterator considered one epoch.

    Raises:
        ValueError: if `steps_per_epoch` is not provided and `dataset` has
            infinite cardinality.
    """
    cardinality = tf.keras.backend.get_value(dataset.cardinality())
    if steps_per_epoch is None:
        steps_per_epoch = cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise ValueError(
                "steps_per_epoch must be provided if dataset has infinite "
                "cardinality")
        dataset = dataset.repeat()
    elif cardinality != tf.data.INFINITE_CARDINALITY:
        assert cardinality == steps_per_epoch
        dataset = dataset.repeat()
    return iter(dataset), steps_per_epoch
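A minimal usage sketch (my own illustration, not part of kblocks) showing how the returned iterator and steps_per_epoch drive one epoch of a custom training loop; `train_step` is a hypothetical function:

ds = tf.data.Dataset.range(10).batch(2)
iterator, steps_per_epoch = as_infinite_iterator(ds)
for _ in range(steps_per_epoch):
    batch = next(iterator)      # fetch the next batch from the repeated dataset
    # train_step(batch)         # hypothetical per-batch training step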
Example #2
def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    # can't save whole model with model subclassing api due to tf bug
    # see: https://github.com/tensorflow/tensorflow/issues/42741
    # https://github.com/tensorflow/tensorflow/issues/40366
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_mean_acc',
        mode='auto',
        save_best_only=True)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)
    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model.fit(
                x=train_dataset.repeat(),
                validation_data=eval_dataset,
                epochs=params.train_epoch,
                callbacks=[model_checkpoint_callback, tensorboard_callback],
                steps_per_epoch=params.train_steps_per_epoch)
    else:
        model.fit(x=train_dataset.repeat(),
                  validation_data=eval_dataset,
                  epochs=params.train_epoch,
                  callbacks=[model_checkpoint_callback, tensorboard_callback],
                  steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
Example #3
  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if self._num_gpus > 1:
      dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
      # Shuffle the input files.
      dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
      dataset = dataset.repeat()

    # Read the data from disk in parallel
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
      dataset = dataset.cache()

    if self.is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)
      dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
      # The batch size of the dataset will be multiplied by the number of
      # replicas automatically when strategy.distribute_datasets_from_function
      # is called, so we use local batch size here.
      dataset = dataset.batch(self.local_batch_size,
                              drop_remainder=self.is_training)
    else:
      dataset = dataset.batch(self.global_batch_size,
                              drop_remainder=self.is_training)

    # Apply Mixup
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(self.mixup, self.local_batch_size, mixup_alpha),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
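The comment about local vs. global batch size refers to how such a pipeline is usually wired into `tf.distribute`. A rough sketch (the name `builder` and the file pattern are illustrative, not from the original code):

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
    files = tf.data.Dataset.list_files("/path/to/train-*.tfrecord", shuffle=False)
    return builder.pipeline(files)   # `builder` is an instance of the class above

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)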
Example #4
def preprocess(dataset: tf.data.Dataset, feature_layer: tf.keras.layers,
               target_feature: str,
               num_epochs: int,
               shuffle_buffer: int,
               batch_size: int,
               batches_to_take=None):
    """
    Preprocess data with a single-element label (label of length one).

    :param dataset: the dataset to preprocess.
    :param feature_layer: feature layer to use to preprocess the data.
    :param target_feature: the name of the target feature (used to extract the correct
    element from the input observations).
    :param num_epochs: number of epochs to repeat the dataset for.
    :param shuffle_buffer: buffer size used when shuffling.
    :param batch_size: batch size of the output dataset.
    :param batches_to_take: if set, only this many batches are returned.
    :return: the repeated, shuffled, and batched dataset.
    """

    def element_fn(element):
        # element_fn extracts feature and label vectors from each element;
        # 'x' and 'y' names are required by keras.
        feature_vector = feature_layer(element)

        return collections.OrderedDict([
            ('x', tf.reshape(feature_vector, [feature_vector.shape[1]])),
            ('y', tf.reshape(element[target_feature], [1])),
        ])

    preprocessed_dataset = dataset.repeat(num_epochs).map(element_fn).shuffle(
        shuffle_buffer).batch(batch_size)
    if not batches_to_take:
        return preprocessed_dataset
    else:
        return preprocessed_dataset.take(batches_to_take)
Example #5
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # repeat before map so stateless map is different across epochs
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # if map_func has random elements this won't be deterministic
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
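To see why the repeat is applied before the stateless map, here is a small illustration (my own, independent of tfrng): seeding each element by its global index keeps the pipeline reproducible while still varying the random values across epochs:

base = tf.data.Dataset.range(3)
ds = base.repeat(2).enumerate().map(
    lambda i, x: tf.random.stateless_uniform(
        [], seed=tf.stack([i, tf.constant(42, tf.int64)])))
print(list(ds.as_numpy_iterator()))  # six distinct values, identical on every run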
Example #6
File: repeated.py Project: jackd/kblocks
 def __init__(self,
              dataset: tf.data.Dataset,
              steps_per_epoch: Optional[int] = None):
     cardinality = tf.keras.backend.get_value(dataset.cardinality())
     if steps_per_epoch is None:
         steps_per_epoch = cardinality
         if cardinality == tf.data.INFINITE_CARDINALITY:
             raise ValueError(
                 "steps_per_epoch must be provided if dataset has infinite "
                 "cardinality")
         dataset = dataset.repeat()
     elif cardinality != tf.data.INFINITE_CARDINALITY:
         assert cardinality == steps_per_epoch
         dataset = dataset.repeat()
     self._dataset = dataset
     self._steps_per_epoch = steps_per_epoch
Example #7
    def process(self, dataset: tf.data.Dataset, batch_size: int):
        dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)

        if self.cache:
            dataset = dataset.cache()

        if self.shuffle:
            dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True)

        if self.indefinite:
            dataset = dataset.repeat()

        # PADDED BATCH the dataset
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=(
                tf.TensorShape([]),
                tf.TensorShape(self.speech_featurizer.shape),
                tf.TensorShape([]),
                tf.TensorShape(self.text_featurizer.shape),
                tf.TensorShape([]),
                tf.TensorShape(self.text_featurizer.prepand_shape),
                tf.TensorShape([]),
            ),
            padding_values=(None, 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0),
            drop_remainder=self.drop_remainder
        )

        # PREFETCH to overlap input preprocessing with model execution
        dataset = dataset.prefetch(AUTOTUNE)
        self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)
        return dataset
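A tiny, self-contained illustration (not from the library above) of what `padded_batch` does with variable-length elements: each batch is padded to the length of its longest element:

ds = tf.data.Dataset.range(1, 4).map(tf.range)   # elements: [0], [0 1], [0 1 2]
ds = ds.padded_batch(2)                          # pad to the longest element per batch
for batch in ds:
    print(batch.numpy())
# [[0 0]
#  [0 1]]
# [[0 1 2]]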
Example #8
def iterator_from_dataset(
    dataset: tf.data.Dataset,
    batch_size: int,
    repeat: bool = True,
    prefetch_size: int = 0,
    devices: Optional[Sequence[Any]] = None,
):
    """Create a data iterator that returns JAX arrays from a TF dataset.

    Args:
      dataset: the dataset to iterate over.
      batch_size: the batch sizes the iterator should return.
      repeat: whether the iterator should repeat the dataset.
      prefetch_size: the number of batches to prefetch to device.
      devices: the devices to prefetch to.

    Returns:
      An iterator that returns data batches.
    """
    if repeat:
        dataset = dataset.repeat()

    if batch_size > 0:
        dataset = dataset.batch(batch_size)
        it = map(prepare_tf_data, dataset)
    else:
        it = map(prepare_tf_data_unbatched, dataset)

    if prefetch_size > 0:
        it = jax_utils.prefetch_to_device(it, prefetch_size, devices)

    return it
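`prepare_tf_data` is not shown here; in typical JAX input pipelines it looks roughly like the sketch below (an assumption, not necessarily this project's helper): convert TF tensors to NumPy and fold the batch into a leading device axis for `pmap`:

import jax

def prepare_tf_data(xs):
    def _prepare(x):
        x = x._numpy()  # NumPy view of the tf.Tensor, avoiding a copy where possible
        return x.reshape((jax.local_device_count(), -1) + x.shape[1:])
    return jax.tree_util.tree_map(_prepare, xs)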
Example #9
    def get_test_tfdataset(self,
                           test_dataset: tf.data.Dataset) -> tf.data.Dataset:
        """
        Returns a test :class:`~tf.data.Dataset`.

        Args:
            test_dataset (:class:`~tf.data.Dataset`): The dataset to use.

        Subclass and override this method if you want to inject some custom behavior.
        """

        num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

        if num_examples < 0:
            raise ValueError(
                "The training dataset must have an asserted cardinality")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (test_dataset.repeat().batch(
            self.args.eval_batch_size,
            drop_remainder=self.args.dataloader_drop_last).prefetch(
                tf.data.experimental.AUTOTUNE))

        return self.args.strategy.experimental_distribute_dataset(
            ds), steps, num_examples
Example #10
def preprocess(dataset: tf.data.Dataset, num_epoch: int,
               batch_size: int) -> tf.data.Dataset:
    def batch_format_fn(element: Dict[str, tf.Tensor]):
        return (tf.expand_dims(element["pixels"], axis=-1), element["label"])

    return dataset.repeat(num_epoch).shuffle(100).batch(batch_size).map(
        batch_format_fn)
Example #11
def run_distilibert(strategy: tf.distribute.TPUStrategy, x_train: np.array,
                    x_valid: np.array, _y_train: np.array, y_valid: np.array,
                    train_dataset: tf.data.Dataset,
                    valid_dataset: tf.data.Dataset,
                    test_dataset: tf.data.Dataset, max_len: int, epochs: int,
                    batch_size: int) -> tf.keras.models.Model:
    """
    Create and run DistilBERT on the training and test data.
    """
    logger.info('build distilbert')

    with strategy.scope():
        transformer_layer = TFDistilBertModel.from_pretrained(MODEL)
        model = build_model(transformer_layer, max_len=max_len)
    model.summary()

    # train given model
    n_steps = x_train.shape[0] // batch_size
    history = model.fit(train_dataset,
                        steps_per_epoch=n_steps,
                        validation_data=valid_dataset,
                        epochs=epochs)
    plot_train_val_loss(history, 'distilbert')

    n_steps = x_valid.shape[0] // batch_size
    _train_history_2 = model.fit(valid_dataset.repeat(),
                                 steps_per_epoch=n_steps,
                                 epochs=epochs * 2)

    scores = model.predict(test_dataset, verbose=1)
    logger.info(f"AUC: {roc_auc(scores, y_valid):.4f}")

    return model
Example #12
File: data.py Project: ak110/pytoolkit
def mixup(
    ds: tf.data.Dataset,
    postmix_fn: typing.Callable[..., typing.Any] = None,
    num_parallel_calls: int = None,
):
    """tf.dataでのmixup: <https://arxiv.org/abs/1710.09412>

    Args:
        ds: 元のデータセット
        postmix_fn: mixup後の処理
        num_parallel_calls: premix_fnの並列数

    """
    @tf.function
    def mixup_fn(*data):
        r = _tf_random_beta(alpha=0.2, beta=0.2)
        data = [
            tf.cast(d[0], tf.float32) * r + tf.cast(d[1], tf.float32) * (1 - r)
            for d in data
        ]
        return data if postmix_fn is None else postmix_fn(*data)

    ds = ds.repeat()
    ds = ds.batch(2)
    ds = ds.map(
        mixup_fn,
        num_parallel_calls=num_parallel_calls,
        deterministic=None if num_parallel_calls is None else False,
    )
    return ds
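The helper `_tf_random_beta` is not shown above; one common way to sample Beta(alpha, beta) in TensorFlow (a sketch of what it might look like, not necessarily pytoolkit's implementation) is via two Gamma draws:

def _tf_random_beta(alpha: float, beta: float) -> tf.Tensor:
    x = tf.random.gamma([], alpha)
    y = tf.random.gamma([], beta)
    return x / (x + y)   # X/(X+Y) ~ Beta(alpha, beta) for X~Gamma(alpha), Y~Gamma(beta)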
Example #13
 def train(self, dataset: tf.data.Dataset, nr_records: int):
     dataset = dataset.batch(self.batch_size).map(self.transform_example)
     dataset = dataset.repeat()
     dataset = dataset.shuffle(1000)
     self.model.fit(dataset,
                    epochs=self.epochs,
                    steps_per_epoch=nr_records // self.batch_size)
Example #14
    def get_test_tfdataset(self,
                           test_dataset: tf.data.Dataset) -> tf.data.Dataset:
        """
        Returns a test :class:`~tf.data.Dataset`.

        Args:
            test_dataset (:class:`~tf.data.Dataset`):
                The dataset to use. The dataset should yield tuples of ``(features, labels)`` where ``features`` is
                a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is
                calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such
                as when using a QuestionAnswering head model with multiple targets, the loss is instead calculated
                by calling ``model(features, **labels)``.

        Subclass and override this method if you want to inject some custom behavior.
        """

        num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

        if num_examples < 0:
            raise ValueError(
                "The training dataset must have an asserted cardinality")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (test_dataset.repeat().batch(
            self.args.eval_batch_size,
            drop_remainder=self.args.dataloader_drop_last).prefetch(
                tf.data.experimental.AUTOTUNE))

        return self.args.strategy.experimental_distribute_dataset(
            ds), steps, num_examples
Example #15
def preprocess_dataset(dataset: tf.data.Dataset, batch_size: int, n_step_returns: int, discount: float):
  d_len = sum([1 for _ in dataset])
  dataset = dataset.map(lambda *x:
                               n_step_transition_from_episode(*x, n_step=n_step_returns,
                                                              additional_discount=discount))
  dataset = dataset.repeat().shuffle(d_len).batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
Example #16
def prepare_for_testing(data_set: tf.data.Dataset, batch_size, cache_path=''):
    if cache_path != '':
        cache_filename = 'dataset_test.tfcache'
        data_set = data_set.cache(''.join([cache_path, '/', cache_filename]))

    data_set = data_set.repeat()
    data_set = data_set.batch(batch_size=batch_size)

    return data_set
Example #17
def _prepare_test_dataset(dataset: tf.data.Dataset, batch_size, cache_path=''):
    if cache_path != '':
        cache_filename = 'dataset_test.tfcache'
        dataset = dataset.cache(
            os.path.join(opt.data_path, cache_path, cache_filename))
        # dataset = dataset.cache(''.join([cache_path, '/', cache_filename]))

    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size=batch_size)

    return dataset
Example #18
 def preprocess_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
   if shuffle_buffer_size > 1:
     dataset = dataset.shuffle(shuffle_buffer_size, seed=debug_seed)
   if preprocess_spec.num_epochs > 1:
     dataset = dataset.repeat(preprocess_spec.num_epochs)
   if preprocess_spec.max_elements is not None:
     dataset = dataset.take(preprocess_spec.max_elements)
   dataset = dataset.batch(preprocess_spec.batch_size, drop_remainder=False)
   return dataset.map(
       mapping_fn,
       num_parallel_calls=num_parallel_calls,
       deterministic=debug_seed is not None)
Example #19
def train_fn(ds: tf.data.Dataset,
             batch_size=1,
             shuffle=10000,
             repeat: int = None):
    '''Create input function for training, prediction, evaluation.'''

    if shuffle:
        ds = ds.shuffle(shuffle)
    ds = ds.batch(batch_size)
    if repeat != 1:
        ds = ds.repeat(repeat)

    return lambda: ds.make_one_shot_iterator().get_next()
Example #20
 def fit(
         self,
         train: tf.data.Dataset,
         valid: Optional[tf.data.Dataset] = None,
         valid_freq=500,
         valid_interval=0,
         optimizer='adam',
         learning_rate=1e-3,
         clipnorm=None,
         epochs=-1,
         max_iter=1000,
         sample_shape=(),  # for ELBO
         analytic=False,  # for ELBO
         iw=False,  # for ELBO
         callback=lambda: None,
         compile_graph=True,
         autograph=False,
         logging_interval=2,
         skip_fitted=False,
         log_tag='',
         log_path=None):
     if self.is_fitted and skip_fitted:
         return self
     from odin.exp.trainer import Trainer
     trainer = Trainer()
     self.trainer = trainer
     # create the optimizer
     if optimizer is not None and self.optimizer is None:
         self.optimizer = _to_optimizer(optimizer, learning_rate, clipnorm)
     if self.optimizer is None:
         raise RuntimeError("No optimizer found!")
     self._trainstep_kw = dict(sample_shape=sample_shape,
                               iw=iw,
                               elbo_kw=dict(analytic=analytic))
     # if already called repeat, then no need to repeat more
     if hasattr(train, 'repeat'):
         train = train.repeat(int(epochs))
     trainer.fit(train_ds=train,
                 optimize=self.optimize,
                 valid_ds=valid,
                 valid_freq=valid_freq,
                 valid_interval=valid_interval,
                 compile_graph=compile_graph,
                 autograph=autograph,
                 logging_interval=logging_interval,
                 log_tag=log_tag,
                 log_path=log_path,
                 max_iter=max_iter,
                 callback=callback)
     self._trainstep_kw = dict()
     return self
Example #21
def _prepare_dataset(
        dataset: tf.data.Dataset,
        global_batch_size: int,
        shuffle: bool,
        rng: np.ndarray,
        preprocess_fn: Optional[Callable[[Any], Any]] = None,
        num_epochs: Optional[int] = None,
        filter_fn: Optional[Callable[[Any], Any]] = None) -> tf.data.Dataset:
    """Batches, shuffles, prefetches and preprocesses a dataset.

  Args:
    dataset: The dataset to prepare.
    global_batch_size: The global batch size to use.
    shuffle: Whether to shuffle the data at the example level.
    rng: PRNG for seeding the shuffle operations.
    preprocess_fn: Preprocessing function that will be applied to every example.
    num_epochs: Number of epochs to repeat the dataset.
    filter_fn: Function that filters samples according to some criteria.

  Returns:
    The dataset.
  """
    if shuffle and rng is None:
        raise ValueError("Shuffling without RNG is not supported.")

    if global_batch_size % jax.host_count() != 0:
        raise ValueError(
            f"Batch size {global_batch_size} not divisible by number "
            f"of hosts ({jax.host_count()}).")
    local_batch_size = global_batch_size // jax.host_count()
    batch_dims = [jax.local_device_count(), local_batch_size]

    # tf.data uses single integers as seed.
    if rng is not None:
        rng = rng[0]

    ds = dataset.repeat(num_epochs)
    if shuffle:
        ds = ds.shuffle(1024, seed=rng)

    if preprocess_fn is not None:
        ds = ds.map(preprocess_fn,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if filter_fn is not None:
        ds = ds.filter(filter_fn)

    for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
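A quick illustration (my own numbers) of what the reversed `batch_dims` loop produces: with 8 local devices and a local batch size of 16, each element gets shape [8, 16, ...], ready to feed `jax.pmap`:

ds = tf.data.Dataset.from_tensor_slices(tf.zeros([1024, 28, 28]))
for batch_size in reversed([8, 16]):     # [jax.local_device_count(), local_batch_size]
    ds = ds.batch(batch_size, drop_remainder=True)
print(ds.element_spec.shape)             # (8, 16, 28, 28)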
Example #22
    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with repeated loops over the input elements.

        Args:
            ds_input: Any dataset.

        Returns:
            A `tf.data.Dataset` with elements containing the same keys, but repeated for
            `epochs` iterations.
        """
        if self.repeat:
            return ds_input.repeat(count=self.epochs)
        else:
            return ds_input
Example #23
 def fit(self, data: tf.data.Dataset, epochs=1, steps_per_epoch=1,
         validation_data=None, validation_steps=1,
         **flow_kwargs):
     data = data.repeat(epochs)
     if validation_data is not None:
         validation_data = validation_data.repeat(epochs)
     test_hist = dict()
     for epoch in range(epochs):
         train_hist = dict()
         with tqdm(total=steps_per_epoch, desc=f'train, epoch {epoch+1}/{epochs}') as prog:
             for i, (x, y) in enumerate(data.take(steps_per_epoch)):
                 loss, nll = self.train_batch(x, y, **flow_kwargs)
                 utils.update_metrics(train_hist, loss=loss.numpy(), nll=nll.numpy())
                 prog.update(1)
                 prog.set_postfix(utils.get_metrics(train_hist))
         with tqdm(total=validation_steps, desc=f'test, epoch {epoch+1}/{epochs}') as prog:
             if validation_data is None:
                 continue
             for i, (x, y) in enumerate(validation_data.take(validation_steps)):
                 nll = self.eval_batch(x, y, **flow_kwargs)
                 utils.update_metrics(test_hist, nll=nll.numpy())
                 prog.update(1)
                 prog.set_postfix(utils.get_metrics(test_hist))
     return test_hist
Example #24
    def repeat(data: tf.data.Dataset) -> tf.data.Dataset:
        """
        Repeat dataset

        Parameters
        ----------
        data
            tensorflow dataset to repeat

        Returns
        -------
        data_repeated
            repeated data
        """
        data = data.repeat()
        return data
Example #25
def prepare_dataset(
    dataset: tf.data.Dataset,
    model_image_size: Tuple[int, int],
    augmentation_fn: Optional[ImageDataMapFn] = None,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int] = None,
    num_parallel_calls: Optional[int] = None,
    prefetch_buffer_size: Optional[int] = None,
    prefetch_to_device: Optional[str] = None,
) -> tf.data.Dataset:

    # apply data augmentation:
    if augmentation_fn is not None:
        dataset = dataset.map(
            map_image_data(augmentation_fn),
            num_parallel_calls=num_parallel_calls,
        )

    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(
        map_image_data(prepare_for_batching(model_image_size)),
        num_parallel_calls=num_parallel_calls,
    )

    # batching and padding
    if batch_size is not None:
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=get_padding_shapes(
                dataset, spatial_image_shape=model_image_size),
            drop_remainder=True,
        )

    # try to prefetch dataset on certain device
    if prefetch_to_device is not None:
        prefetch_fn = tf.data.experimental.prefetch_to_device(
            device=prefetch_to_device, buffer_size=prefetch_buffer_size)
        dataset = dataset.apply(prefetch_fn)
    else:
        if prefetch_buffer_size is not None:
            dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

    return dataset
Example #26
    def get_tfds_data_loader(data: tf.data.Dataset, data_subset_mode='train',
                             batch_size=32, num_samples=100, num_classes=19,
                             infinite=True, augment=True, seed=2836):


        def encode_example(x, y):
            x = tf.image.convert_image_dtype(x, tf.float32) * 255.0
            y = _encode_label(y, num_classes=num_classes)
            return x, y

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.shuffle(buffer_size=num_samples) \
                   .cache() \
                   .map(encode_example, num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.map(preprocess_input, num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        if data_subset_mode == 'train':
            data = data.shuffle(buffer_size=100, seed=seed)
            augmentor = TRAIN_image_augmentor
        elif data_subset_mode == 'val':
            augmentor = VAL_image_augmentor
        elif data_subset_mode == 'test':
            augmentor = TEST_image_augmentor

        if augment:
            data = augmentor.apply_augmentations(data)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.batch(batch_size, drop_remainder=True)
        if infinite:
            data = data.repeat()

        return data.prefetch(AUTOTUNE)
Example #27
def prepare_for_training(data_set: tf.data.Dataset,
                         batch_size,
                         cache_path=None,
                         shuffle_buffer_size=1000):
    if cache_path:  # default is None, so only cache when a path is given
        cache_filename = 'dataset_train.tfcache'
        data_set = data_set.cache(''.join([cache_path, '/', cache_filename]))

    data_set = data_set.shuffle(buffer_size=shuffle_buffer_size)
    # repeat forever
    data_set = data_set.repeat()
    data_set = data_set.batch(batch_size=batch_size)
    # `prefetch` lets the dataset fetch batches in the background
    # while the model is training.
    data_set = data_set.prefetch(buffer_size=AUTOTUNE)

    return data_set
Example #28
def preprocess_tf_dataset(dataset: tf.data.Dataset,
                          hparams: ClientDataHParams) -> tf.data.Dataset:
    """Preprocesses dataset according to the dataset hyperparmeters.

  Args:
    dataset: Dataset with a mapping element structure.
    hparams: Hyperparameters for dataset preparation.

  Returns:
    Preprocessed dataset.
  """
    dataset = dataset.repeat(hparams.num_epochs)
    if hparams.shuffle_buffer_size:
        dataset = dataset.shuffle(hparams.shuffle_buffer_size)
    dataset = (dataset.batch(
        hparams.batch_size, drop_remainder=hparams.drop_remainder).prefetch(1))
    return dataset.take(hparams.num_batches)
Example #29
File: profile.py Project: jackd/kblocks
def profile_model(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    inference_only: bool = False,
    **kwargs,
):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    it = iter(dataset)
    model_func = (model.make_predict_function()
                  if inference_only else model.make_train_function())

    def func():
        return model_func(it)

    return profile_func(func,
                        **kwargs,
                        name="predict" if inference_only else "train")
Example #30
def benchmark_model(model: tf.keras.Model,
                    dataset: tf.data.Dataset,
                    inference_only=False,
                    **kwargs):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    inputs, labels, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(
        as_inputs(dataset))
    if inference_only:
        op = model(inputs)
    else:
        variables = model.trainable_variables
        with tf.GradientTape() as tape:
            predictions = model(inputs)
            loss = model.loss(labels, predictions, sample_weight=sample_weight)
        grads = tape.gradient(loss, variables)
        op = model.optimizer.apply_gradients(zip(grads, variables))
    return benchmark_op(op, **kwargs)