def test_return_all_elements(self):
    for use_threads in (True, False):
        with self.subTest('use_threads=%s' % use_threads):
            for size in [100, 10000, 100000]:
                dataset = list(
                    MultiWorkerCallableIterator(
                        ((i,) for i in range(size)),
                        identity,
                        use_threads=use_threads))
                self.assertSetEqual(set(dataset), set(range(size)),
                                    'Some returned elements are missing.')
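# Context for the test above (a sketch, not part of the original file):
# `identity` is the trivial worker callable the test hands to
# MultiWorkerCallableIterator, and the import assumes the usual
# dpu_utils package layout.
from dpu_utils.utils import MultiWorkerCallableIterator


def identity(x):
    # Return the argument unchanged, so the test exercises only the
    # iterator's scheduling and result collection.
    return x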
def read_data_chunks(data_chunk_paths: Iterable[RichPath],
                     shuffle_chunks: bool = False,
                     max_queue_size: int = 1,
                     num_workers: int = 0) -> Iterable[List[Dict[str, Any]]]:
    """Yield the parsed contents of each chunk file, optionally shuffled and read in parallel."""
    if shuffle_chunks:
        data_chunk_paths = list(data_chunk_paths)
        np.random.shuffle(data_chunk_paths)

    if num_workers <= 0:
        # Sequential fallback: read each chunk in the calling thread.
        for data_chunk_path in data_chunk_paths:
            yield data_chunk_path.read_by_file_suffix()
    else:
        def read_chunk(data_chunk_path: RichPath):
            return data_chunk_path.read_by_file_suffix()

        # Read chunks in worker threads; chunks are yielded as they complete.
        yield from MultiWorkerCallableIterator(
            argument_iterator=[(data_chunk_path,) for data_chunk_path in data_chunk_paths],
            worker_callable=read_chunk,
            max_queue_size=max_queue_size,
            num_workers=num_workers,
            use_threads=True)
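# Usage sketch for read_data_chunks (illustrative only; the directory
# layout, chunk count, and count_samples helper are hypothetical).
# read_by_file_suffix picks the parser from the file extension
# (e.g. .jsonl.gz), so each yielded chunk is one parsed file.
from dpu_utils.utils import RichPath


def count_samples(chunk_dir: str, num_chunks: int) -> int:
    chunk_paths = [RichPath.create('%s/chunk_%03d.jsonl.gz' % (chunk_dir, i))
                   for i in range(num_chunks)]
    total = 0
    # Four reader threads, buffering at most five parsed chunks at a time.
    for chunk in read_data_chunks(chunk_paths, shuffle_chunks=True,
                                  max_queue_size=5, num_workers=4):
        total += len(chunk)
    return total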
def test(self,
         test_raw_data_chunk_paths: List[RichPath],
         beam_size: int = 5,
         per_result_callback_fn: Optional[Callable[[int, float, Dict[str, Any], ModelTestResult], None]] = None,
         train_model=None) -> int:
    """Evaluate the model on the given raw data chunks; returns the number of samples seen."""
    def read_chunk(raw_data_chunk_path: RichPath):
        return raw_data_chunk_path.read_by_file_suffix()

    data_chunk_iterator = MultiWorkerCallableIterator(
        argument_iterator=[(data_chunk_path,) for data_chunk_path in test_raw_data_chunk_paths],
        worker_callable=read_chunk,
        max_queue_size=3,
        num_workers=2,
        use_threads=True)

    sample_idx = 0
    for raw_data_chunk in data_chunk_iterator:
        for raw_sample in raw_data_chunk:
            sample_idx += 1
            loaded_train_sample = dict()
            loaded_train_sample['Provenance'] = raw_sample['Filename'] + "::" + raw_sample['HoleLineSpan']
            prod_root_node = min(int(v) for v in raw_sample['Productions'].keys())
            sample_token_seq = []
            collect_token_seq(raw_sample, prod_root_node, sample_token_seq)

            if len(raw_sample['VariableUsageContexts']) == 0:
                assert len(raw_sample['LastUseOfVariablesInScope']) == 0
                continue

            loaded_test_sample = dict(loaded_train_sample)
            use_example = self._load_data_from_sample(self.hyperparameters,
                                                      self.metadata,
                                                      raw_sample=raw_sample,
                                                      result_holder=loaded_train_sample,
                                                      is_train=True)
            if not use_example:
                continue

            # Step (1): Compute perplexity:
            train_feed_dict = next(train_model._data_to_minibatches(loaded_train_sample, is_train=True))[0]
            sample_log_prob = train_model.sess.run(train_model._decoder_model.ops['log_probs'],
                                                   feed_dict=train_feed_dict)
            token_perplexity = np.exp(-sample_log_prob / len(sample_token_seq))

            # Step (2): Compute accuracy:
            self._load_data_from_sample(self.hyperparameters,
                                        self.metadata,
                                        raw_sample=raw_sample,
                                        result_holder=loaded_test_sample,
                                        is_train=False)
            test_feed_dict = self._tensorise_one_test_sample(loaded_test_sample)
            test_feed_dict[self.__placeholders['batch_size']] = 1
            test_sample_encoding, context_encoding = self._encode_one_test_sample(test_feed_dict)

            if context_encoding is None:  # TODO: Hack that should go away
                test_result = self._decoder_model.generate_suggestions_for_one_sample(
                    loaded_test_sample, raw_sample, sample_idx, test_sample_encoding,
                    beam_size=beam_size)  # type: ModelTestResult
            else:
                test_result = self._decoder_model.generate_suggestions_for_one_sample(
                    loaded_test_sample, raw_sample, sample_idx, test_sample_encoding,
                    beam_size=beam_size,
                    context_tokens=loaded_test_sample.get('context_nonkeyword_tokens'),
                    context_token_representations=context_encoding[0],
                    context_token_mask=test_feed_dict[self.placeholders['context_token_mask']])  # type: ModelTestResult

            if per_result_callback_fn is not None:
                per_result_callback_fn(sample_idx, token_perplexity, raw_sample, test_result)

    return sample_idx
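# Hedged usage sketch for test() above: `model` and `train_model` stand
# for already-constructed and restored model instances (their setup is
# project-specific and not shown), and the chunk path is hypothetical.
def print_result(sample_idx: int, token_perplexity: float,
                 raw_sample: Dict[str, Any],
                 test_result: ModelTestResult) -> None:
    # Per-sample callback: report provenance and perplexity.
    print('#%d %s::%s  token perplexity: %.3f'
          % (sample_idx, raw_sample['Filename'],
             raw_sample['HoleLineSpan'], token_perplexity))


test_paths = [RichPath.create('test_chunk_000.jsonl.gz')]  # hypothetical path
num_samples = model.test(test_paths, beam_size=5,
                         per_result_callback_fn=print_result,
                         train_model=train_model)
print('Evaluated %d samples.' % num_samples)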