Example #1
File: iterator.py Project: cesarali/Tyche
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1

                seq_len, time = self.__series_2_bptt(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time'])])

                yield (Batch.fromvars(
                        dataset, self.batch_size,
                        time=(ti, l)) for ti, l in zip(time, seq_len))

                if not self.repeat:
                    return
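The method above yields, for every minibatch, a generator of BPTT windows rather than a single Batch. Below is a minimal consumption sketch; the iterator instance `it` and the loop body are illustrative assumptions, not part of the Tyche source.

# Hypothetical consumption sketch: `it` is assumed to be an instance of the
# iterator class whose __iter__ is shown above, built over a dataset with a
# 'time' field and constructed with repeat=False.
for windows in it:                           # one generator of BPTT windows per minibatch
    for window in windows:                   # each window is a torchtext Batch
        padded_times, lengths = window.time  # (padded time tensor, per-example lengths)
        print(padded_times.shape, lengths)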
Example #2
File: iterator.py Project: cesarali/Tyche
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                # should we have many batches, or should we have one long batch with many windows?
                seq_len, text, time = self.__series_2_bptt(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time']), ('text', self.dataset.fields['text']),
                    ('target_time', Field(use_vocab=False)), ('target_text', Field(use_vocab=False))])
                yield (Batch.fromvars(
                        dataset, self.batch_size,
                        time=(ti[:, :, :2], l),
                        text=te[:, :, 0],
                        target_time=ti[:, :, -1],
                        target_text=te[:, :, 1]) for ti, te, l in zip(time, text, seq_len))

            if not self.repeat:
                return
Example #3
    def next_example(self, example=None):
        # the mini-batch only contains one example due to batch_size == 1
        if example is None:
            self.raw_example_batch = next(self.env_data_iter)
        else:
            self.raw_example_batch = Batch(data=[example],
                                           dataset=self.train_set,
                                           device=self.device)

        # record the example ids so results can be traced back to the data sources
        self.raw_example_ids = self.raw_example_batch.Id.tolist()

        # get raw prediction probability
        _, input_batch, input_lengths, _ = self.prepare_relation_mini_batch(
            self.device, self.raw_example_batch)
        self.net.eval()
        with torch.no_grad():
            input_emb_seq, pred_logp = self.net.get_init_state_info(
                input_batch, input_lengths)

        # set input embedding (word embedding + position embedding)
        self.input_emb_seq = input_emb_seq
        self.input_seq_lens = input_lengths

        # set raw prediction log-probabilities: target class (index 1) and all classes
        self.raw_tgt_logp = pred_logp[:, 1]
        self.raw_pred_logps = pred_logp.tolist()
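A hedged driver sketch for next_example follows; the instance name `env` and the surrounding attribute layout are inferred from what the method reads and sets, not taken from the original project.

# Hypothetical driver sketch: `env` is assumed to already hold env_data_iter,
# train_set, net and device, as the method above requires.
env.next_example()              # pull the next single-example mini-batch
print(env.raw_example_ids)      # ids for tracing back to the data source
print(env.raw_tgt_logp)         # log-probability of the target class (index 1)
print(env.raw_pred_logps)       # per-class log-probabilities for the example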
Example #4
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                # should we have many batches, or should we have one long batch with many windows?
                batch_size: int = len(batch)

                dataset = Dataset(examples=self.dataset.examples, fields=self.fields)
                if self.train:
                    yield from self.__bptt_2_minibatch(batch, batch_size, dataset)
                else:
                    if 'bow' in dataset.fields.keys():
                        batch.target_bow = batch.bow[:, :, 1]
                        batch.bow = batch.bow[:, :, 0]
                    else:
                        batch.text = (batch.text[0], batch.text[2])
                    batch.target_time = batch.time[0][:, :, -1]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    yield batch

            if not self.repeat:
                return
Example #5
def create_homogenous_batches(dataset: Dataset, max_batch_size: int,
                              key_fn=lambda ex: -len(ex.src),
                              filter_fn=lambda ex: len(ex.src) > 0,
                              sort: bool = True) -> List[Batch]:
    """
    Creates a list of batches such that for each batch b it holds that:
        - b contains at least one and at most max_batch_size examples
        - for any two examples e1, e2 in b: key_fn(e1) == key_fn(e2)
        - b does not contain any example e for which filter_fn(e) == False
    In addition, the batches b are sorted by (increasing) key_fn(e) for any e in b.

    Args:
        dataset: the dataset to take the batches from
        max_batch_size: how many examples one batch may contain at the most
        key_fn: function of type (Example) -> int, that is used to sort the batches. Each batch will only have examples
            that all have exactly the same key (e.g. source sentence length).
        filter_fn: function of type (Example) -> bool, that is used to filter the examples. No example e with
            filter_fn(e) == False will be contained in any batch.
        sort: whether or not to sort the examples (by the given key_fn)

    Returns: a list of batches with the above properties
    """
    sorted_examples = sorted(dataset.examples, key=key_fn) if sort else dataset.examples

    same_key_blocks = []
    previous_key = -1
    current_block = []

    for example in sorted_examples:
        if not filter_fn(example):
            continue

        key = key_fn(example)
        if previous_key == -1 or key != previous_key:
            previous_key = key
            # start a new block
            if len(current_block) > 0:
                same_key_blocks.append(current_block)
                current_block = []
        # append current example to corresponding block
        current_block.append(example)

    # append last block
    if len(current_block) > 0:
        same_key_blocks.append(current_block)

    # split up blocks in batches of size at most max_batch_size
    batches = []
    for block in same_key_blocks:
        i = 0
        while i < len(block):
            # take the next at most max_batch_size examples from this block
            data = block[i:i + max_batch_size]
            batches.append(Batch(data=data, dataset=dataset))
            i += len(data)

    return batches
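A self-contained usage sketch, assuming the legacy torchtext API (torchtext.data in 0.8 and earlier, torchtext.legacy.data afterwards); the toy sentences and field setup are illustrative assumptions.

# Usage sketch under the legacy torchtext API; everything below is illustrative.
from torchtext.data import Dataset, Example, Field

src_field = Field()                            # default whitespace tokenizer
fields = [('src', src_field)]
sentences = ["a b c", "d e f", "g h", "i", ""]
examples = [Example.fromlist([s], fields) for s in sentences]
dataset = Dataset(examples, fields)
src_field.build_vocab(dataset)                 # Batch() numericalizes, so a vocab is needed

batches = create_homogenous_batches(dataset, max_batch_size=2)
for b in batches:
    # within a batch every source has the same length; the empty sentence is filtered out
    print(len(b), b.src.shape)                 # e.g. 2 torch.Size([3, 2])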
Example #6
def get_interactive_batch(args, vid, question, iters):
    batch = []
    for key, it in iters.items():
        for ex in it.it.dataset.examples:
            if ex.vid == vid:
                batch.append((key, ex))
    tokenizer = get_tokenizer(args)
    ex = batch[check_que_sim(tokenizer, question, batch)]
    key, ex = ex
    ex = Batch([ex], iters[key].it.dataset)
    ex.images = iters[key].get_image(ex.vid)

    return ex
Example #7
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                yield minibatch, Batch(minibatch, self.dataset, self.device)
            if not self.repeat:
                return
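Unlike the other variants, this __iter__ yields the raw example list alongside the processed Batch. A hedged consumption sketch follows; the instance name `it` is an assumption.

# Hypothetical consumption sketch: `it` is assumed to be an instance of the
# iterator class above, constructed with repeat=False.
for raw_examples, batch in it:
    # raw_examples: the list of torchtext Example objects in this minibatch
    # batch: the numericalized torchtext Batch built from the same examples
    assert len(raw_examples) == len(batch)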
Example #8
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                batch_size: int = len(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time']), ('mark', self.dataset.fields['mark']),
                    ('target_time', Field(use_vocab=False)), ('target_mark', Field(use_vocab=False))])
                if self.train:
                    seq_len, time, mark = self.__series_2_bptt(batch, batch_size)
                    yield (Batch.fromvars(
                            dataset, batch_size,
                            time=(ti[:, :, :2], l),
                            mark=m[:, :, 0],
                            target_time=ti[:, :, -1],
                            target_mark=m[:, :, -1]) for ti, l, m in zip(time, seq_len, mark))
                else:
                    batch.target_time = batch.time[0][:, :, 2]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    batch.target_mark = batch.mark[:, :, 1]
                    batch.mark = batch.mark[:, :, 0]
                    yield batch

            if not self.repeat:
                return