def pool(d, random_shuffler):
    for p in data.batch(d, self.batch_size * 100):
        p_batch = data.batch(
            sorted(p, key=self.sort_key),
            self.batch_size, self.batch_size_fn)
        for b in random_shuffler(list(p_batch)):
            yield b
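The pattern above is torchtext-style bucketing: read a large chunk, sort it so that similar-length examples sit next to each other, cut it into minibatches, then shuffle the minibatches so training still sees them in random order. A self-contained sketch of the same idea with plain lists (demo_pool and its arguments are illustrative names, not part of the original code):

import random

def demo_pool(examples, batch_size, sort_key, chunk_mult=100):
    # Read a chunk of chunk_mult * batch_size examples at a time.
    chunk_len = batch_size * chunk_mult
    for start in range(0, len(examples), chunk_len):
        # Sort within the chunk so similar-length examples are adjacent.
        chunk = sorted(examples[start:start + chunk_len], key=sort_key)
        # Cut the sorted chunk into minibatches...
        batches = [chunk[i:i + batch_size]
                   for i in range(0, len(chunk), batch_size)]
        # ...and shuffle the minibatches, not the individual examples.
        random.shuffle(batches)
        yield from batches

sents = ["a", "a b", "a b c", "a b c d", "a b c d e", "a b c d e f"]
for minibatch in demo_pool(sents, batch_size=2, sort_key=len, chunk_mult=3):
    print(minibatch)  # similar-length sentences end up together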
Example #2
import random


def pool(data,
         batch_size,
         key,
         batch_size_fn=lambda new, count, sofar: count,
         random_shuffler=None,
         shuffle=False,
         shuffle_batch_no=100000,
         sort_within_batch=False):
    """Sort within buckets, then batch, then shuffle batches.

    Partitions data into chunks of size shuffle_batch_no*batch_size, sorts
    examples within each chunk using key, then batches these examples and,
    if requested, shuffles the batches.
    """
    if random_shuffler is None:
        # random.shuffle shuffles in place and returns None, so wrap it to
        # return the shuffled list, as the loop below expects.
        def random_shuffler(batches):
            random.shuffle(batches)
            return batches
    for p in batch(data, batch_size * shuffle_batch_no, batch_size_fn):
        p_batch = batch(sorted(p, key=key), batch_size, batch_size_fn) \
            if sort_within_batch \
            else batch(p, batch_size, batch_size_fn)
        if shuffle:
            for b in random_shuffler(list(p_batch)):
                yield b
        else:
            for b in list(p_batch):
                yield b
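All of these examples delegate to a batch helper that the snippets do not show. A minimal sketch, assuming torchtext-style semantics in which batch_size_fn reports the running "effective size" of the minibatch (the default simply counts examples):

def batch(data, batch_size, batch_size_fn=None):
    """Yield minibatches whose effective size is at most batch_size."""
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif size_so_far > batch_size:
            # The newest example overflowed the budget: emit everything
            # before it and start the next minibatch with it.
            yield minibatch[:-1]
            minibatch = minibatch[-1:]
            size_so_far = batch_size_fn(ex, 1, 0)
    if minibatch:
        yield minibatch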
Example #3
def pool(d, batch_size, key, batch_size_fn=lambda new, count, sofar: count,
         random_shuffler=None, shuffle=False, sort_within_batch=False):
    """Sort within buckets, then batch, then shuffle batches.

    Partitions data into chunks of size 100*batch_size, sorts examples within
    each chunk using key, then batches these examples and shuffles the
    batches.

    This pool function was changed to deal with larger batches: the
    batch_size_fn argument was removed from the outer call, which previously
    read `for p in data.batch(d, batch_size * 100, batch_size_fn):`.
    """
    for p in data.batch(d, batch_size * 100):
        p_batch = data.batch(
            sorted(p, key=key), batch_size, batch_size_fn)
        for b in random_shuffler(list(p_batch)):
            yield b
Example #4
    def _do_batch(self) -> List[List[Example]]:
        groups: Dict[Any, List[Example]] = defaultdict(list)
        for example in self.data():
            groups[self.group_by(example)].append(example)

        minibatches = []
        for examples in groups.values():
            for minibatch in batch(examples, self.batch_size):
                minibatches.append(minibatch)
        return minibatches
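The effect of this group-then-batch strategy is that a minibatch never mixes groups. A self-contained illustration (group_then_batch and the toy data are made-up names for this sketch):

from collections import defaultdict

def group_then_batch(examples, batch_size, group_by):
    groups = defaultdict(list)
    for ex in examples:
        groups[group_by(ex)].append(ex)
    # Fixed-size batching within each group; the tail batch may be short.
    return [members[i:i + batch_size]
            for members in groups.values()
            for i in range(0, len(members), batch_size)]

data = [("a", 1), ("b", 2), ("a", 3), ("a", 4), ("b", 5)]
print(group_then_batch(data, 2, group_by=lambda ex: ex[0]))
# [[('a', 1), ('a', 3)], [('a', 4)], [('b', 2), ('b', 5)]]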
Example #5
def create_batches(self):
    if self.sort:
        self.batches = batch(self.data(), self.batch_size,
                             self.batch_size_fn)
    else:
        self.batches = pool(self.data(), self.batch_size,
                            self.sort_key, self.batch_size_fn,
                            random_shuffler=self.random_shuffler,
                            shuffle=self.shuffle,
                            sort_within_batch=self.sort_within_batch)
Example #6
    def dialogue_pool(self,
                      data,
                      batch_size,
                      key_inner,
                      key_outer,
                      batch_size_fn=lambda new, count, sofar: count,
                      random_shuffler=None):
        """Sort within buckets, then batch, then shuffle batches.

        Partitions data into chunks of size 100*batch_size, sorts examples within
        each chunk using sort_key, then batch these examples and shuffle the
        batches.
        """
        if random_shuffler is None:
            # random.shuffle shuffles in place and returns None, so wrap it
            # to return the shuffled list, as the loop below expects.
            def random_shuffler(batches):
                random.shuffle(batches)
                return batches
        for p in batch(data, batch_size * 100, batch_size_fn):
            p_batch = batch(sorted(sorted(p, key=key_inner), key=key_outer),
                            batch_size, batch_size_fn)
            for b in random_shuffler(list(p_batch)):
                yield b
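Because Python's sort is stable, the nested sorted(sorted(p, key=key_inner), key=key_outer) orders each chunk primarily by key_outer and breaks ties by key_inner. An equivalent single sort, assuming both keys return comparable values:

p_batch = batch(sorted(p, key=lambda ex: (key_outer(ex), key_inner(ex))),
                batch_size, batch_size_fn)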
Example #7
def create_batches(self):
    if self.sort:
        self.batches = batch(self.data(), self.batch_size,
                             self.batch_size_fn)
    else:
        self.batches = self.dialogue_pool(
            self.data(),
            self.batch_size,
            self.sort_key_inner,
            self.sort_key_outer,
            self.batch_size_fn,
            random_shuffler=self.random_shuffler)
Example #8
def create_batches(self):
    if self.train:
        def pool(d, random_shuffler):
            for p in data.batch(d, self.batch_size * 100):
                p_batch = data.batch(sorted(p, key=self.sort_key),
                                     self.batch_size, self.batch_size_fn)
                for b in random_shuffler(list(p_batch)):
                    yield b
        self.batches = pool(self.data(), self.random_shuffler)
    else:
        self.batches = []
        for b in data.batch(self.data(), self.batch_size, self.batch_size_fn):
            self.batches.append(sorted(b, key=self.sort_key))
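Note the train/eval split here (and in the later examples): training batches are bucketed by length and shuffled for throughput, while evaluation batches are only sorted by sort_key, which keeps iteration deterministic and per-batch padding small.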
Example #9
File: dataset.py  Project: snucclab/EPT
    def _generate_batches(self):
        """
        Make a generator for the batches.
        This method enforces that each batch contains items with similar lengths.

        :return: This function yields a batched item (ProblemInstance)
            - text: ProblemTextInstance for given problem text
                - token: Long Tensor of index of text tokens. Shape [B, S],
                    where B = batch size and S = length of tokenized problem text sequence
                - pad: Bool Tensor for indicating padded positions, Shape [B, S].
                - number: Long Tensor for indicating number indices that a token belongs to. Shape [B, S].
                - number_value: Dictionary representing value of the numbers in the text.
            - op_gen: A LongTensor representing op-token indices. Shape [B, P],
                where P = length of op-token sequence.
            - expr_gen: A LongTensor representing expression-token indices (without pointer). Shape [B, X, 1+2A],
                where X = length of expression-token sequence, and A = maximum arity.
            - expr_ptr: A LongTensor representing expression-token indices (with pointer). Shape [B, X, 1+2A]
            - index: List of problem IDs in the dataset
            - expected: List of expected answer tuples
        """
        max_token_size = 0
        items = []
        dataset = self._dataset

        # Chunk the dataset into much larger groups of items than the specified batch size.
        chunks = list(batch(dataset, self._batch_size * 1024, _get_item_size))
        for batch_group in chunks:
            # Sort within each group of items
            for item in sorted(batch_group, key=_get_token_length):
                items.append(item)

                # Compute the max-length key and new batch size.
                token_size = max(_get_token_length(item))
                max_token_size = max(max_token_size, token_size)
                batch_size = max_token_size * len(items)

                # If the size budget is reached or exceeded, flush.
                if batch_size == self._batch_size:
                    yield self._concatenate_batch(items)
                    items = []
                    max_token_size = 0
                elif batch_size > self._batch_size:
                    yield self._concatenate_batch(items[:-1])
                    items = items[-1:]
                    max_token_size = token_size

            # If items remain at the end of the chunk, flush them as a final
            # (possibly short) batch, and reset the accumulators so the next
            # chunk does not re-yield them.
            if items:
                yield self._concatenate_batch(items)
                items = []
                max_token_size = 0
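Two helpers in this excerpt are not shown: _get_item_size (the batch_size_fn used when chunking) and _get_token_length (the sort key, which evidently returns an iterable of lengths, since the loop takes max(...) of it). The loop then packs items while max_token_size * len(items) stays within the batch-size budget, bounding the padded tensor area rather than the raw example count. Hypothetical shapes for the two helpers, purely illustrative rather than the EPT project's actual code:

def _get_item_size(new, count, sofar):
    # Hypothetical batch_size_fn: effective chunk size = number of items,
    # matching the (new, count, sofar) signature that batch() expects.
    return count

def _get_token_length(item):
    # Hypothetical sort key: the sequence lengths that drive padding cost,
    # as a tuple, so that max(_get_token_length(item)) picks the dominant
    # dimension when computing the token budget.
    return (len(item.text.token), len(item.op_gen))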
Example #10
    def create_batches(self):
        if self.train:

            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(sorted(p, key=self.sort_key),
                                         self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b

            # Note: `pool` chunks the data, sorts each chunk by sort_key,
            # batches it, and shuffles the resulting batches.
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))
Example #11
def create_batches(self):
    if self.train:
        self.batches = self.pool(self.data(), self.random_shuffler)
    else:
        self.batches = data.batch(self.data(), self.batch_size,
                                  self.batch_size_fn)
Example #12
def create_batches(self):
    self.batches = []
    for b in batch(self.data(), self.batch_size, self.batch_size_fn):
        self.batches.append(sorted(b, key=self.sort_key))