async def get(self) -> MutableSequence[Token]:
    """Wait for and return the next group of tokens that share a tag.

    Tokens received from the ports are bucketed by their ``tag`` in
    ``self.token_values``. As soon as a bucket holds as many tokens as there
    are ports, that bucket is removed and returned through ``flatten_list``.

    Returns:
        The flattened bucket, or ``[TerminationToken(self.name)]`` as soon as
        any received input is recognised by ``utils.check_termination``.
    """
    while True:
        # Serve an already-complete bucket before polling the ports again.
        for tag, bucket in list(self.token_values.items()):
            if len(bucket) == len(self.ports):
                return flatten_list(self.token_values.pop(tag))
        # Poll every port once, concurrently.
        pending = [asyncio.create_task(port.get()) for port in self.ports.values()]
        received = await asyncio.gather(*pending)
        for item in received:
            # A terminated port ends the whole combination.
            if utils.check_termination(item):
                return [TerminationToken(self.name)]
            if isinstance(item, MutableSequence):
                incoming = item
            elif isinstance(item, Token):
                incoming = (item,)
            else:
                continue
            for tok in incoming:
                self.token_values.setdefault(tok.tag, []).append(tok)
def terminate(self, status: Status):
    """Mark this step as finished with the given status.

    Does nothing if the step is already terminated. Otherwise it puts a
    ``TerminationToken`` (named after the port) on every output port, stores
    ``status`` in ``self.status``, sets ``self.terminated`` and logs the
    outcome.

    Args:
        status: Final status of the step; its ``name`` is written to the log.
    """
    if self.terminated:
        return
    # Every output port receives a termination marker for this step.
    for port in self.output_ports.values():
        port.put(TerminationToken(name=port.name))
    self.status = status
    self.terminated = True
    # Pass the arguments to logging instead of pre-formatting, so interpolation
    # only happens when INFO records are actually emitted.
    logger.info("Step %s terminated with status %s", self.name, status.name)
async def _inject_inputs(self, step: Step, job: Job):
    """Feed every unconnected input port of ``step`` from a synthetic output port.

    For each input port whose ``dependee`` is ``None``, a ``DefaultOutputPort``
    with the same name is created and bound to ``step``. It shares the input
    port's token processor and is pre-loaded with two tokens:

    1. The token computed from a ``JupyterCommandOutput`` that carries
       ``self.user_ns`` (value ``None``, status ``COMPLETED``).
    2. A ``TerminationToken``.

    The synthetic port then becomes the input port's ``dependee``.
    """
    for name, in_port in step.input_ports.items():
        # Ports that already have a producer are left untouched.
        if in_port.dependee is not None:
            continue
        injected = DefaultOutputPort(name=name)
        injected.step = step
        injected.token_processor = in_port.token_processor
        result = JupyterCommandOutput(
            value=None,
            status=Status.COMPLETED,
            user_ns=self.user_ns)
        token = await injected.token_processor.compute_token(job, result)
        injected.put(token)
        injected.put(TerminationToken(name))
        in_port.dependee = injected
async def get(self, consumer: Text) -> Token:
    """Return the next output token for ``consumer``.

    Outputs come from ``self._retrieve(consumer)``. If they signal termination
    (per ``utils.check_termination``), a ``TerminationToken`` named after this
    port is returned. Otherwise the outputs are flattened and, when a
    ``merge_strategy`` is configured, passed through ``self._merge``.

    A resulting sequence is wrapped in one aggregate ``Token``:

    - ``job``: the members' jobs, as a list.
    - ``value``: the members themselves.
    - ``tag``: computed by ``get_tag``.
    - ``weight``: the sum of the members' weights.

    Any other result is returned unchanged.
    """
    retrieved = await self._retrieve(consumer)
    if utils.check_termination(retrieved):
        return TerminationToken(self.name)
    merged = flatten_list(retrieved)
    if self.merge_strategy is not None:
        merged = self._merge(merged)
    if not isinstance(merged, MutableSequence):
        return merged
    return Token(
        name=self.name,
        job=[member.job for member in merged],
        value=merged,
        tag=get_tag(merged),
        weight=sum(member.weight for member in merged),
    )
async def _cartesian_multiplier(self):
    """Continuously enqueue the cartesian combinations produced by new tokens.

    One ``get()`` task is kept in flight per port, and each task is named after
    its port key.

    When a port delivers a regular token:

    - The token is combined with every token previously seen on the *other*
      ports for the same job (as returned by ``_get_job_name``).
    - Each combination is put on ``self.queue`` as a list.
    - The token is then remembered in ``self.token_lists[job][port]``.
    - A fresh ``get()`` task is started for that port.

    When a port delivers a termination, it is recorded in ``self.terminated``
    and is not polled again. Once every port has terminated, a single
    ``[TerminationToken(self.name)]`` is enqueued and the coroutine returns.
    """
    pending = [
        asyncio.create_task(port.get(), name=name)
        for name, port in self.ports.items()
    ]
    while True:
        done, still_running = await asyncio.wait(
            pending, return_when=FIRST_COMPLETED)
        pending = list(still_running)
        for completed in done:
            source = cast(Task, completed).get_name()
            token = completed.result()
            is_end = isinstance(token, TerminationToken) or (
                isinstance(token, MutableSequence)
                and utils.check_termination(token))
            if is_end:
                self.terminated.append(source)
                # The combinator ends only after its last port has ended.
                if len(self.terminated) == len(self.ports):
                    self.queue.put_nowait([TerminationToken(self.name)])
                    return
                continue
            job_name = _get_job_name(token)
            if job_name not in self.token_lists:
                self.token_lists[job_name] = {name: [] for name in self.ports}
            per_port = self.token_lists[job_name]
            # The new token stands alone on its own port; other ports
            # contribute everything seen so far for this job.
            factors = [
                [token] if name == source else seen
                for name, seen in per_port.items()
            ]
            for combo in list(itertools.product(*factors)):
                self.queue.put_nowait(list(combo))
            per_port[source].append(token)
            pending.append(
                asyncio.create_task(self.ports[source].get(), name=source))
async def _initialize(self):
    """Collect one initial token per port and start the cartesian multiplier.

    All ports are polled concurrently. If ``utils.check_termination`` reports
    termination for the initial batch, a single ``[TerminationToken(self.name)]``
    is enqueued and nothing else happens.

    Otherwise, in order:

    1. Each initial token is recorded in ``self.token_lists[job][port_key]``,
       the same per-job / per-port layout ``_cartesian_multiplier`` reads and
       extends. This lets later tokens be combined with the initial ones.
    2. The initial combination is enqueued.
    3. ``_cartesian_multiplier`` is started as a background task.
    """
    # Port keys of ``self.ports`` are used throughout (as _cartesian_multiplier
    # does), rather than mixing them with ``port.name``.
    port_keys = list(self.ports)
    initial = await asyncio.gather(
        *[asyncio.create_task(port.get()) for port in self.ports.values()])
    inputs = dict(zip(port_keys, initial))
    # Check for early termination of the initial batch.
    if utils.check_termination(list(inputs.values())):
        self.queue.put_nowait([TerminationToken(self.name)])
        return
    for key, token in inputs.items():
        job_lists = self.token_lists.setdefault(
            _get_job_name(token), {name: [] for name in port_keys})
        job_lists[key].append(token)
    self.queue.put_nowait(list(inputs.values()))
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    self._multiplier_task = asyncio.create_task(self._cartesian_multiplier())
async def _get(self, consumer: Text) -> None:
    """Forward tokens from all ports into ``self.queues[consumer]``.

    One ``get(consumer)`` task is kept in flight per port, and each task is
    named after its port key. Every non-terminating token is put on the
    consumer's queue and its port is polled again. A port that reports
    termination (per ``utils.check_termination``) is no longer polled.

    Once no port is left, a single ``TerminationToken`` named after this
    combinator is enqueued. Termination is decided only after the whole batch
    returned by ``asyncio.wait`` has been processed, so tokens that complete
    in the same batch as the last termination are never dropped.
    """
    tasks = [
        asyncio.create_task(port.get(consumer), name=port_name)
        for port_name, port in self.ports.items()
    ]
    # Looping on ``tasks`` also avoids asyncio.wait([]) raising ValueError
    # when there are no ports at all.
    while tasks:
        finished, unfinished = await asyncio.wait(
            tasks, return_when=FIRST_COMPLETED)
        tasks = list(unfinished)
        for task in finished:
            task_name = cast(Task, task).get_name()
            token = task.result()
            if not utils.check_termination(token):
                self.queues[consumer].put_nowait(token)
                # Re-poll the port that just delivered a regular token.
                tasks.append(
                    asyncio.create_task(
                        self.ports[task_name].get(consumer), name=task_name))
    self.queues[consumer].put_nowait(TerminationToken(name=self.name))