Example #1
def shuffle_with_seed(self, dataset, ds_kwargs):
  # Fall back to the configured fixed seed when the caller did not provide one.
  if 'seed' not in ds_kwargs or ds_kwargs['seed'] is None:
    logger.warn("Shuffling with fixed shuffle seed {}.".format(self.shuffle_seed))
    ds_kwargs['seed'] = self.shuffle_seed
  else:
    logger.debug("Shuffling with shuffle seed {}.".format(ds_kwargs['seed']))
  return dataset.shuffle(**ds_kwargs)
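A minimal standalone sketch of the same fallback pattern, using plain tf.data and a hypothetical fixed seed of 42 (the logger name and `shuffle_seed` value below are illustrative, not taken from the library):

import logging
import tensorflow as tf

logger = logging.getLogger("shuffle-demo")   # hypothetical logger
shuffle_seed = 42                            # stand-in for self.shuffle_seed

def shuffle_with_fallback_seed(dataset, ds_kwargs):
  # Inject the fixed seed only if the caller did not set one.
  if 'seed' not in ds_kwargs or ds_kwargs['seed'] is None:
    logger.warning("Shuffling with fixed shuffle seed %s.", shuffle_seed)
    ds_kwargs['seed'] = shuffle_seed
  return dataset.shuffle(**ds_kwargs)

ds = shuffle_with_fallback_seed(tf.data.Dataset.range(10), {'buffer_size': 10})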
Example #2
    def _create_tnt_model(cls, model: tf.keras.Model,
                          parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.ALL if TF_DEFAULT_PIPELINING_FLAG \
                                                                                             else tnt.ParallelStrategy.DATA,
                          num_pipeline_stages: int = 1):
        replica_group = tnt.Group()

        if (tnt.ParallelStrategy.PIPELINING
                in parallel_strategy) and isinstance(model,
                                                     tf.keras.Sequential):
            logger.warn(
                f"Cannot pipeline a `tf.keras.Sequential` model; disabling model parallelism."
            )
            parallel_strategy = parallel_strategy ^ tnt.ParallelStrategy.PIPELINING

        logger.info(f"Creating parallel model using {parallel_strategy}.")
        if tnt.ParallelStrategy.PIPELINING in parallel_strategy:
            rank = tnt.get_rank()

            partition_generator = pgen.GraphPartitionGenerator(model)
            rank_mapper = rmapper.RankMapper(
                num_ranks=tnt.get_size(),
                pipeline_graph=partition_generator.get_pipeline_graph())
            pipeline_group = rank_mapper.get_pipelining_group_for_rank(rank)

            logger.info(
                f"[Pipelining] Creating pipelined model with {pipeline_group.size} partitions."
            )
            # get my partition
            model = pm.PartitionedModel(
                model=model,
                group=pipeline_group,
                partition_generator=partition_generator,
                rank_mapper=rank_mapper,
                num_pipeline_stages=num_pipeline_stages)
            if tnt.ParallelStrategy.DATA in parallel_strategy:
                replica_group = rank_mapper.get_replica_group_for_rank(rank)
            else:
                if pipeline_group.size != tnt.get_size():
                    raise ValueError(
                        f"Provided model has only {pipeline_group.size} partitions; use {pipeline_group.size} ranks or a different parallel strategy."
                    )

        if tnt.ParallelStrategy.DATA in parallel_strategy:
            # replicate my partition across the data parallel group
            logger.info(
                f"[DataParallel] Replicating local model across ranks {replica_group.group}."
            )
            model = dpm.DataParallelModel(model=model, group=replica_group)
        return model
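The strategy handling above relies on flag-style membership (`in`) and removal (`^`). A small self-contained sketch of those semantics, using a stand-in `enum.Flag` rather than the library's `tnt.ParallelStrategy`:

import enum

class ParallelStrategy(enum.Flag):   # stand-in for tnt.ParallelStrategy
  DATA = 1
  PIPELINING = 2
  ALL = DATA | PIPELINING

strategy = ParallelStrategy.ALL
print(ParallelStrategy.PIPELINING in strategy)   # True

# Disable pipelining, as done for `tf.keras.Sequential` models above.
strategy = strategy ^ ParallelStrategy.PIPELINING
print(strategy)                                  # ParallelStrategy.DATA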
Example #3
  def _configure_rebuild(self, dataset):
    self.built = False
    dist_dataset = tnt.data.Dataset(dataset = dataset,
                                    num_ranks = 1,
                                    rank = 0)
    dist_dataset.distribute_dataset_across_ranks(apply_batch = False)

    # model is already built with the same `nano_batch_size`
    if self.nano_batch_size == dist_dataset.micro_batch_size // self.num_pipeline_stages:
      self.built = True
      return

    micro_batch_size = dist_dataset.micro_batch_size
    self.nano_batch_size = micro_batch_size // self.num_pipeline_stages
    if self.nano_batch_size * self.num_pipeline_stages != micro_batch_size:
      logger.warn(f"[PartitionedModel] The micro-batch size {self.micro_batch_size} is not a multiple of "
                  f" the number of pipeline stages ({self.num_pipeline_stages}); removing the remainder.")
Example #4
def _pad_dataset_if_necessary(dataset, num_samples, batch_size,
                              min_last_batch_size):
    last_batch_size = _get_last_incomplete_batch_size(num_samples, batch_size)
    if last_batch_size == 0:
        logger.debug(f"No padding required: number of samples {num_samples} is a multiple " \
                     f"of the batch size {batch_size}.")
        return dataset

    logger.info(f"Incomplete last batch in the dataset: number of samples is " \
                f"{last_batch_size} ( != batch size {batch_size}).")

    if version_utils.tf_version_below_equal('2.1'):
        num_samples_multiple = num_samples - last_batch_size
        logger.warn(f"Number of samples ({num_samples}) is not a multiple of batch size. " \
                    f"This use case is not supported in TF v{version_utils.current_version()}. " \
                    f"Dropping the last incomplete batch from the dataset, "\
                    f"and proceeding with {num_samples_multiple} samples.")
        return dataset.take(num_samples_multiple)

    if last_batch_size < min_last_batch_size:
        logger.debug(f"Padding required for the last batch: number of samples is " \
                     f"{last_batch_size} ( < min_batch_size {min_last_batch_size}).")

        # Create helper dataset that contains one full batch and one incomplete batch
        helper_dataset = dataset.take(min_last_batch_size + last_batch_size)
        helper_dataset = helper_dataset.batch(min_last_batch_size,
                                              drop_remainder=False)

        # If `padded_shape` is unspecified, all dimensions of all components
        # are padded to the maximum size in the batch.
        # The second batch in `helper_dataset` will now contain `min_last_batch_size - last_batch_size`
        # default-initialized samples.
        helper_dataset = helper_dataset.padded_batch(2)

        # Switch back to a list of samples instead of batches
        helper_dataset = helper_dataset.unbatch().unbatch()

        # Remaining samples in the dataset are those generated through padding
        padding_samples = helper_dataset.skip(min_last_batch_size +
                                              last_batch_size)
        dataset = dataset.concatenate(padding_samples)
        logger.info(f"[Rank {tnt.get_rank()}] Dataset padded with " \
                    f"{min_last_batch_size - last_batch_size} samples.")
    return dataset
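A toy, self-contained version of the padding trick above (plain tf.data, TF 2.2+, no library types); with 10 samples and a batch size of 4, the last batch holds only 2 samples, so 2 default-valued samples are appended:

import tensorflow as tf

batch_size = 4
dataset = tf.data.Dataset.range(10)          # last batch would hold only 2 samples
last_batch_size = 10 % batch_size            # 2

# One full batch plus the incomplete one, padded to a rectangular shape.
helper = dataset.take(batch_size + last_batch_size)
helper = helper.batch(batch_size, drop_remainder=False)
helper = helper.padded_batch(2)              # pads the short batch with zeros
padding_samples = helper.unbatch().unbatch().skip(batch_size + last_batch_size)

padded = dataset.concatenate(padding_samples)
print(list(padded.as_numpy_iterator()))      # [0, ..., 9, 0, 0] -> 12 samples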
Example #5
  def distributed_batch(self, dataset, batch_size, micro_batch_size):
    if self.batching_info.drop_remainder:
      dataset = self.batching_info.apply(dataset, new_batch_size = batch_size)
      dataset = dataset.unbatch()

    else: # no drop remainder
      num_samples = ds_helpers.get_num_samples(dataset)
      if num_samples == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("[DistributedDataset] Infinite dataset provided")

      # Total number of samples is not multiple of the batch size
      if num_samples % batch_size != 0:
        logger.warn("Number of samples ({}) is not a multiple of batch size.\
 Removing the last incomplete batch from the dataset.".format(num_samples))
        num_samples_multiple = (num_samples // batch_size) * batch_size
        dataset = dataset.take(num_samples_multiple)

    dataset = self.batching_info.apply(dataset, new_batch_size = micro_batch_size)
    dataset = dataset.shard(num_shards=self.num_ranks, index = self.rank)

    logger.info("Using batch size = {}, micro batch size = {}.".format(
                batch_size, micro_batch_size))
    return dataset
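A stripped-down sketch of the final batching-and-sharding step, using plain tf.data calls and hypothetical values for the rank and the number of ranks:

import tensorflow as tf

num_ranks, rank = 2, 0                         # hypothetical values
batch_size = 8
micro_batch_size = batch_size // num_ranks     # 4

dataset = tf.data.Dataset.range(16)
dataset = dataset.batch(micro_batch_size, drop_remainder=True)
dataset = dataset.shard(num_shards=num_ranks, index=rank)

for micro_batch in dataset:
  print(micro_batch.numpy())   # rank 0 sees [0..3] and [8..11]; rank 1 would see the rest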
Example #6
def _validate_micro_batch_size_for_batch_normalization(self, micro_batch_size):
  # Warn if the model contains BatchNormalization layers and the micro-batch is small,
  # since per-micro-batch statistics become unreliable for very small batches.
  if micro_batch_size < 16:
    for layer in self.layers:
      if isinstance(layer, tf.keras.layers.BatchNormalization):
        logger.warn("Micro batch size should be at least 16 when using Batch Normalization.")
        return
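A small standalone version of the same check, with an illustrative Keras model and a plain `logging` logger (the model and logger name are assumptions, not the library's):

import logging
import tensorflow as tf

logger = logging.getLogger("bn-check-demo")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])

micro_batch_size = 8
if micro_batch_size < 16 and any(isinstance(layer, tf.keras.layers.BatchNormalization)
                                 for layer in model.layers):
  logger.warning("Micro batch size should be at least 16 when using Batch Normalization.")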