Example #1
def test_return_all_elements(self):
    # `identity` is a trivial worker defined elsewhere in the test module;
    # it returns its argument unchanged.
    for use_threads in (True, False):
        with self.subTest('use_threads=%s' % use_threads):
            for size in [100, 10000, 100000]:
                dataset = list(
                    MultiWorkerCallableIterator(
                        ((i,) for i in range(size)),
                        identity,
                        use_threads=use_threads))
                # Results may arrive in any order, so compare as sets.
                self.assertSetEqual(
                    set(dataset), set(range(size)),
                    'Some returned elements are missing.')
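This test exercises the key contract of MultiWorkerCallableIterator: every input tuple is handed to the worker callable exactly once, but results can come back in any order, hence the set comparison. Below is a minimal self-contained sketch of the same pattern, assuming the class is importable from dpu_utils.utils.iterators (as in dpu-utils) and using a made-up square worker:

# Minimal sketch, not from the test suite above; `square` is hypothetical.
from dpu_utils.utils.iterators import MultiWorkerCallableIterator

def square(x):
    return x * x

# Each tuple from argument_iterator is unpacked as *args for the callable.
results = list(
    MultiWorkerCallableIterator(
        argument_iterator=((i,) for i in range(100)),
        worker_callable=square,
        num_workers=4,
        use_threads=True))
assert set(results) == {i * i for i in range(100)}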
Example #2
def read_data_chunks(data_chunk_paths: Iterable[RichPath],
                     shuffle_chunks: bool = False,
                     max_queue_size: int = 1,
                     num_workers: int = 0) -> Iterable[List[Dict[str, Any]]]:
    if shuffle_chunks:
        data_chunk_paths = list(data_chunk_paths)
        np.random.shuffle(data_chunk_paths)
    if num_workers <= 0:
        # Sequential fallback: parse each chunk in the calling thread.
        for data_chunk_path in data_chunk_paths:
            yield data_chunk_path.read_by_file_suffix()
    else:
        # Parse chunks on worker threads and yield them as they complete.
        def read_chunk(data_chunk_path: RichPath):
            return data_chunk_path.read_by_file_suffix()
        yield from MultiWorkerCallableIterator(
            argument_iterator=[(data_chunk_path,)
                               for data_chunk_path in data_chunk_paths],
            worker_callable=read_chunk,
            max_queue_size=max_queue_size,
            num_workers=num_workers,
            use_threads=True)
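Here MultiWorkerCallableIterator overlaps disk reads and decompression across threads, while num_workers <= 0 keeps a simple sequential path. A hypothetical call site, assuming the directory layout is made up and that the RichPath helpers used below (RichPath.create, get_filtered_files_in_dir) behave as in dpu-utils:

# Hypothetical usage; 'data/train' and the glob pattern are placeholders.
from dpu_utils.utils import RichPath

chunk_paths = RichPath.create('data/train').get_filtered_files_in_dir('*.jsonl.gz')
for chunk in read_data_chunks(chunk_paths, shuffle_chunks=True,
                              max_queue_size=5, num_workers=2):
    for sample in chunk:  # each chunk is a list of sample dicts
        ...  # process one raw sample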
Example #3
    def test(self,
             test_raw_data_chunk_paths: List[RichPath],
             beam_size: int = 5,
             per_result_callback_fn: Optional[Callable[
                 [int, float, Dict[str, Any], ModelTestResult], None]] = None,
             train_model=None) -> int:
        def read_chunk(raw_data_chunk_path: RichPath):
            return raw_data_chunk_path.read_by_file_suffix()
        data_chunk_iterator = MultiWorkerCallableIterator(
            argument_iterator=[(data_chunk_path,)
                               for data_chunk_path in test_raw_data_chunk_paths],
            worker_callable=read_chunk,
            max_queue_size=3,
            num_workers=2,
            use_threads=True)

        sample_idx = 0
        for raw_data_chunk in data_chunk_iterator:
            for raw_sample in raw_data_chunk:
                sample_idx += 1

                loaded_train_sample = dict()
                loaded_train_sample['Provenance'] = raw_sample[
                    'Filename'] + "::" + raw_sample['HoleLineSpan']
                prod_root_node = min(
                    int(v) for v in raw_sample['Productions'].keys())
                sample_token_seq = []
                collect_token_seq(raw_sample, prod_root_node, sample_token_seq)
                if len(raw_sample['VariableUsageContexts']) == 0:
                    # Skip samples that contain no variable-usage contexts.
                    assert len(raw_sample['LastUseOfVariablesInScope']) == 0
                    continue
                loaded_test_sample = dict(loaded_train_sample)
                use_example = self._load_data_from_sample(
                    self.hyperparameters,
                    self.metadata,
                    raw_sample=raw_sample,
                    result_holder=loaded_train_sample,
                    is_train=True)
                if not use_example:
                    continue

                # Step (1): Compute perplexity:
                train_feed_dict = next(
                    train_model._data_to_minibatches(loaded_train_sample,
                                                     is_train=True))[0]
                sample_log_prob = train_model.sess.run(
                    train_model._decoder_model.ops['log_probs'],
                    feed_dict=train_feed_dict)
                token_perplexity = np.exp(-sample_log_prob /
                                          len(sample_token_seq))

                # Step (2): Compute accuracy:
                self._load_data_from_sample(self.hyperparameters,
                                            self.metadata,
                                            raw_sample=raw_sample,
                                            result_holder=loaded_test_sample,
                                            is_train=False)
                test_feed_dict = self._tensorise_one_test_sample(
                    loaded_test_sample)
                test_feed_dict[self.__placeholders['batch_size']] = 1
                test_sample_encoding, context_encoding = self._encode_one_test_sample(
                    test_feed_dict)
                if context_encoding is None:  # TODO: Hack that should go away
                    test_result = self._decoder_model.generate_suggestions_for_one_sample(
                        loaded_test_sample,
                        raw_sample,
                        sample_idx,
                        test_sample_encoding,
                        beam_size=beam_size)  # type: ModelTestResult
                else:
                    test_result = self._decoder_model.generate_suggestions_for_one_sample(
                        loaded_test_sample,
                        raw_sample,
                        sample_idx,
                        test_sample_encoding,
                        beam_size=beam_size,
                        context_tokens=loaded_test_sample.get(
                            'context_nonkeyword_tokens'),
                        context_token_representations=context_encoding[0],
                        context_token_mask=test_feed_dict[self.placeholders[
                            'context_token_mask']])  # type: ModelTestResult

                if per_result_callback_fn is not None:
                    per_result_callback_fn(sample_idx, token_perplexity,
                                           raw_sample, test_result)
        return sample_idx
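per_result_callback_fn is invoked once per evaluated sample with (sample_idx, token_perplexity, raw_sample, test_result). A hypothetical callback that aggregates perplexities, matching the Callable[[int, float, Dict[str, Any], ModelTestResult], None] signature declared above:

# Hypothetical callback; `perplexities` and the print format are illustrative.
perplexities = []

def collect_result(sample_idx: int, token_perplexity: float,
                   raw_sample, test_result) -> None:
    perplexities.append(token_perplexity)
    print('#%d %s: token perplexity %.2f'
          % (sample_idx, raw_sample['Filename'], token_perplexity))

# model.test(test_raw_data_chunk_paths, per_result_callback_fn=collect_result,
#            train_model=trained_model)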