Example 1
def is_primary_worker(scope='global'):
    """Check whether is the primary worker of all nodes (global) or the current node (local).

  Args:
  * scope: check scope ('global' OR 'local')

  Returns:
  * flag: whether is the primary worker
  """

    if scope == 'global':
        return not FLAGS.enbl_multi_gpu or mgw.rank() == 0
    elif scope == 'local':
        return not FLAGS.enbl_multi_gpu or mgw.local_rank() == 0
    else:
        raise ValueError('unrecognized worker scope: ' + scope)
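The two scopes matter in multi-GPU runs: 'local' selects one worker per machine (e.g. for node-local setup work), while 'global' selects a single worker in the whole job (e.g. for checkpointing). A minimal usage sketch, assuming the FLAGS / mgw setup above; run_training, prepare_node_data, and saver are hypothetical stand-ins:

def run_training(sess, saver, prepare_node_data):
    # per-node setup (e.g. downloading data): run once on each machine
    if is_primary_worker(scope='local'):
        prepare_node_data()
    # ... training loop ...
    # cluster-wide output: only one worker in the whole job writes checkpoints
    if is_primary_worker(scope='global'):
        saver.save(sess, './model.ckpt')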
Example 2
    def build(self, enbl_trn_val_split=False):
        '''Build iterator(s) for a tf.data.Dataset() object.

        Args:
        * enbl_trn_val_split: whether to split into training & validation subsets

        Returns:
        * iterator_trn: iterator for the training subset
        * iterator_val: iterator for the validation subset
          OR
        * iterator: iterator for the chosen subset (training OR testing)

        Example:
          # build iterator(s)
          dataset = xxxxDataset(is_train=True)  # TF operations are not created
          iterator = dataset.build()            # TF operations are created
              OR
          iterator_trn, iterator_val = dataset.build(enbl_trn_val_split=True)  # for the training dataset only

          # use the iterator to obtain a mini-batch of images & labels
          images, labels = iterator.get_next()
        '''

        # obtain list of data files' names
        filenames = tf.data.Dataset.list_files(self.file_pattern, shuffle=True)
        if self.enbl_shard:
            filenames = filenames.shard(mgw.size(), mgw.rank())

        # create a tf.data.Dataset from list of files
        dataset = filenames.apply(
            tf.contrib.data.parallel_interleave(
                self.dataset_fn, cycle_length=FLAGS.cycle_length))
        dataset = dataset.map(self.parse_fn,
                              num_parallel_calls=FLAGS.nb_threads)

        # create iterators for training & validation subsets separately
        if self.is_train and enbl_trn_val_split:
            iterator_val = self.__make_iterator(
                dataset.take(FLAGS.nb_smpls_val))
            iterator_trn = self.__make_iterator(
                dataset.skip(FLAGS.nb_smpls_val))
            return iterator_trn, iterator_val

        return self.__make_iterator(dataset)
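For completeness, a hedged sketch of how the returned iterator(s) might be consumed in a TF 1.x graph/session workflow; CifarDataset is an illustrative subclass name, not part of the snippet above:

import tensorflow as tf

dataset = CifarDataset(is_train=True)                    # illustrative subclass
iterator_trn, iterator_val = dataset.build(enbl_trn_val_split=True)
images, labels = iterator_trn.get_next()                 # mini-batch tensors

with tf.Session() as sess:
    # depending on how __make_iterator constructs the iterator, an explicit
    # sess.run(iterator_trn.initializer) may be needed before the first batch
    batch_images, batch_labels = sess.run([images, labels])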
Example 3
  def __is_primary_worker(cls):
    return not FLAGS.enbl_multi_gpu or mgw.rank() == 0
Example 4
  def __is_primary_worker(cls):
    """Whether this is the primary worker."""
    return not FLAGS.enbl_multi_gpu or mgw.rank() == 0