def visit_With(self, node: Any) -> None:
    """Visit a with statement.

    Only the ``retry()`` context manager is supported. It must be the only
    context manager in the statement and must wrap exactly one statement. The
    wrapped statement is visited first; the ``retry(...)`` call node is then
    handed to ``add_retries`` on whatever state is current afterwards
    (presumably the task that statement created).

    Examples::

        with retry(
            on_exceptions=["CustomError1", "CustomError2"],
            interval=10
        ):
            TaskToRetry()

    Raises:
        UnsupportedOperation: if the context manager is not a ``retry(...)``
            call, if it is combined with other context managers, or if it
            wraps more than one statement.
    """
    context_manager = node.items[0].context_expr
    # Check the node shape before reading ``.func.id``. ``with foo:`` (a Name)
    # or ``with obj.retry():`` (an Attribute callee) would otherwise raise an
    # AttributeError instead of a helpful UnsupportedOperation.
    is_retry_call = (
        isinstance(context_manager, ast.Call)
        and isinstance(context_manager.func, ast.Name)
        and context_manager.func.id == "retry"
    )
    if not is_retry_call:
        raise UnsupportedOperation(
            "Supported context managers include:\n"
            "* ``retry()`` for retrying tasks",
            node,
        )
    # ``with retry(...), other():`` used to silently ignore every item after
    # the first one.
    assert_supported_operation(
        len(node.items) == 1,
        "The retry context manager cannot be combined with other context managers",
        node,
    )
    assert_supported_operation(
        len(node.body) == 1,
        "The retry context manager can only wrap a single task",
        node,
    )
    self.visit(node.body[0])
    task = self._current_state
    task.add_retries(context_manager)
def visit_ExceptHandler(self, node: Any) -> None:
    """Visit an exception handler (a single ``except`` clause).

    The try statement itself is parsed in a different method. The handler is
    attached to the current state, which must be a task state.

    Examples::

        except BadThing:
            TaskToHandleBadThing()

        except States.Timeout:
            TaskToPingSlack()
    """
    logging.debug("Visiting an ExceptHandler")
    owning_task = self._current_state
    assert_supported_operation(
        isinstance(owning_task, TaskState),
        "Only task states can have exception handlers",
        node,
    )
    # The handler body is visited with the object returned by ``add_catch`` as
    # the current state.
    self._set_current_state(owning_task.add_catch(node))
    for statement in node.body:
        self.visit(statement)
    # Point back at the task so that a sibling ``except`` clause is attached to
    # the same task rather than to the end of this handler's body.
    self._set_current_state(owning_task)
def to_dict(self) -> Dict:
    """Return a serialized representation of the ChoiceBranch.

    The result combines the transition fields written by ``_set_end_or_next``
    with the serialized ``if``/``elif`` test expression.

    Raises:
        UnsupportedOperation: if the test does not serialize to a dict (for
            example a bare ``data["foo"]`` with no comparison).
    """
    serialized = {}
    self._set_end_or_next(serialized)
    condition = self._serialize(self.ast_node.test)
    # A top-level condition must be a rule object, not a bare path or literal.
    assert_supported_operation(
        isinstance(condition, dict),
        "Invalid conditional statement. Most likely there is no comparison or"
        " boolean logic present.",
        self.ast_node.test,
    )
    serialized.update(condition)
    return serialized
def visit_AsyncFunctionDef(self, node: Any) -> None:
    """Validate an async function definition.

    Only ``get_custom_tags_*`` coroutines may be defined. Their ordinary
    positional parameters must be exactly ``message``, ``input_data`` and
    ``state_data_client``, in any order.

    NOTE(review): only ``node.args.args`` is compared. Keyword-only
    parameters, ``*args`` and ``**kwargs`` are not inspected here.

    Raises:
        UnsupportedOperation: if the name or the parameter set is not allowed.
    """
    method_name = node.name
    assert_supported_operation(
        method_name.startswith("get_custom_tags_"),
        f"Custom event processors can only implement get_custom_tags_* methods. Provided: {method_name}",
        node,
    )
    declared_params = [param.arg for param in node.args.args]
    required_params = {"message", "input_data", "state_data_client"}
    assert_supported_operation(
        set(declared_params) == required_params,
        "get_custom_tags_* methods must only accept positional arguments of"
        " (message, input_data, state_data_client)."
        f" Provided: {', '.join(declared_params)}",
        node,
    )
def _parse_result_path(self) -> str:
    """Parse the result path from the AST node.

    The result path is where the result value will be inserted into the data
    object. When the node is an assignment, the path comes from its first
    target. Otherwise ``"$"`` is returned (presumably the JSONPath root).

    Returns:
        result path string

    Raises:
        UnsupportedOperation: if the path matches
            ``INVALID_RESULT_PATH_PATTERN`` (reserved keys).
    """
    if not isinstance(self.ast_node, ast.Assign):
        return "$"
    path = convert_input_data_ref(self.ast_node.targets[0])
    assert_supported_operation(
        re.search(INVALID_RESULT_PATH_PATTERN, path) is None,
        "Task result path is invalid. Check that it does not contain reserved"
        f" keys: {', '.join(RESERVED_INPUT_DATA_KEYS)}",
        self.ast_node,
    )
    return path
def visit_ClassDef(self, node) -> None:
    """Visit the class definition.

    This parses class attributes, then recurses into the class body (e.g. the
    methods) via ``generic_visit``. An attribute is recognized when it is a
    simple ``name = value`` assignment whose name is a key of
    ``ATTRIBUTE_MAP``. Its parsed value is stored in ``self.attributes[name]``.
    Other assignments are ignored.

    Raises:
        UnsupportedOperation: if a recognized attribute's value is not one of
            that attribute's ``allowed_values``.
    """
    for item in node.body:
        if not isinstance(item, ast.Assign):
            continue
        target = item.targets[0]
        # Only plain-name targets can be class attributes. Tuple unpacking and
        # attribute/subscript targets have no ``.id`` and used to crash with an
        # AttributeError. They are now skipped, like any unrecognized name.
        if not isinstance(target, ast.Name) or target.id not in ATTRIBUTE_MAP:
            continue
        key = target.id
        attribute = ATTRIBUTE_MAP[key]
        value = attribute.get_value(item.value, visitor=self)
        if attribute.allowed_values is not None:
            allowed = ", ".join(str(allowed_value) for allowed_value in attribute.allowed_values)
            assert_supported_operation(
                value in attribute.allowed_values,
                f"Allowed values for class attribute {key} include: {allowed}",
                # Report the offending assignment rather than the whole class.
                item,
            )
        self.attributes[key] = value
    self.generic_visit(node)
def shape(self) -> None:
    """Shape the graph for this Choice state node.

    If a state follows the Choice state, every choice branch whose last
    descendant is not terminal gets an edge from that descendant to the
    following state.

    Raises:
        UnsupportedOperation: if more than one state follows the Choice state.
    """
    # Neighbours reached through an edge without the ``in_else`` flag are the
    # state(s) that run *after* the whole Choice state. Neighbours on
    # ``in_else`` edges belong to the ``else`` clause instead.
    successors = []
    for neighbour, edge_attrs in self.state_graph.adj[self].items():
        if not edge_attrs.get("in_else"):
            successors.append(neighbour)
    assert_supported_operation(
        len(successors) <= 1,
        "A maximum of 1 state can be downstream from a Choice state",
        self.ast_node,
    )
    if not successors:
        return
    following_state = successors[0]
    for branch in self.choice_branches:
        descendants = branch.descendants
        if len(descendants) == 0:
            continue
        last_state = descendants[-1]
        if not last_state.TERMINAL:
            self.state_graph.add_edge(last_state, following_state)
def _get_result_path(self) -> Optional[str]:
    """Get the ResultPath value for the task state.

    If the result path was not explicitly provided (the AST node is not an
    assignment with a target), return None to indicate that the result should
    be discarded.

    Returns:
        result path string or None

    Raises:
        UnsupportedOperation: if the path matches
            ``INVALID_RESULT_PATH_PATTERN`` (reserved keys).
    """
    node = self.ast_node
    has_target = isinstance(node, ast.Assign) and len(node.targets) > 0
    if not has_target:
        return None
    path = convert_input_data_ref(node.targets[0])
    assert_supported_operation(
        re.search(INVALID_RESULT_PATH_PATTERN, path) is None,
        "Task result path is invalid. Check that it does not contain reserved"
        f" keys: {', '.join(RESERVED_INPUT_DATA_KEYS)}",
        node,
    )
    return path
def to_dict(self):
    """Return a serialized representation of the Choice state.

    The output holds ``Type``, one serialized entry per choice branch under
    ``Choices`` and, when the ``if`` chain has an ``else`` clause, a
    ``Default`` entry holding the key of the state reached from it.

    Raises:
        UnsupportedOperation: if more than one state hangs off the ``else``
            clause.
    """
    serialized = {"Type": "Choice"}
    serialized["Choices"] = [branch.to_dict() for branch in self.choice_branches]
    # Neighbours linked through an edge flagged ``in_else`` form the Default
    # target used when none of the branches match.
    default_targets = [
        neighbour
        for neighbour, edge_attrs in self.state_graph.adj[self].items()
        if edge_attrs.get("in_else")
    ]
    assert_supported_operation(
        len(default_targets) <= 1,
        "A maximum of 1 state can be included in an `else` clause",
        self.ast_node,
    )
    if len(default_targets) == 1:
        (default_state,) = default_targets
        serialized["Default"] = default_state.key
    return serialized
def parse_options(
    option_map: Dict[str, CallableOption],
    node: ast.Call,
    visitor: Optional[ast.NodeVisitor] = None,
) -> OptionsMap:
    """Parse options from keyword arguments passed to a Callable.

    See :py:class:`CallableOption`

    Args:
        option_map: Map of keyword argument name to the CallableOption schema
        node: AST node of the task state being added to the state machine
        visitor: Instance of a node visitor to provide extra context for
            parsing options

    Returns:
        OptionsMap instance, which is a dict of key-values with defaults filled in

    Raises:
        UnsupportedOperation: on ``**`` keyword unpacking, an unknown keyword,
            a value of the wrong AST node type, or a missing required option.
    """
    options = OptionsMap(node)
    # Seed every option with its default. A callable default is computed from
    # the node/visitor context.
    options.update({
        key: option.default_value(node, visitor)
        if callable(option.default_value)
        else option.default_value
        for key, option in option_map.items()
    })

    for keyword in node.keywords:
        key = keyword.arg
        # ``Foo(**kwargs)`` yields a keyword whose ``arg`` is None. Report it
        # explicitly instead of as a generic "invalid keyword argument".
        assert_supported_operation(
            key is not None,
            "Keyword argument unpacking (**kwargs) is not supported",
            node,
        )
        assert_supported_operation(
            key in option_map,
            f"Invalid keyword argument. Options: {', '.join(option_map.keys())}",
            node,
        )
        option = option_map[key]
        if option.value_type:
            assert_supported_operation(
                isinstance(keyword.value, option.value_type),
                f"Invalid data type for the {key} option:"
                f" expected a {option.value_type_label}.",
                node,
            )
        options[key] = option.get_value(keyword.value, visitor)

    # Ensure that all required options were provided
    required_options = {
        key for key, option in option_map.items() if option.required
    }
    provided_options = {keyword.arg for keyword in node.keywords}
    # Sorted so the error message is deterministic. Set iteration order of
    # strings changes between interpreter runs because of hash randomization.
    missing_options = sorted(required_options - provided_options)
    assert_supported_operation(
        len(missing_options) == 0,
        f"The following options are required but were not provided: {', '.join(missing_options)}",
        node,
    )
    return options
def visit_FunctionDef(self, node: Any) -> None:
    """Visit a function definition.

    Every module-level function is, to start, considered a full state machine.
    We'll determine later on whether the state machine is an embedded parallel
    branch, map iterator, or standalone state machine.

    Each decorator must be a call to a name in ``RESOURCE_DECORATOR_MAP``. Its
    keyword arguments are parsed with ``parse_options`` and collected per
    decorator name. They are passed to a new ``StateMachineVisitor``, which
    visits and shapes the function and is registered under the function name.

    Raises:
        UnsupportedOperation: if a decorator is not a call to a supported
            resource decorator, or a decorator is applied more times than its
            ``max_count``.
    """
    # Parse state machine options from the list of function decorators
    options = defaultdict(list)
    for decorator in node.decorator_list:
        # Validate the decorator's shape before reading ``decorator.func.id``.
        # A bare ``@name`` (no call) or an attribute callee such as
        # ``@mod.name()`` has no ``func.id`` and used to crash with an
        # AttributeError before the supported-decorator check could run.
        key = (
            decorator.func.id
            if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name)
            else None
        )
        assert_supported_operation(
            key in RESOURCE_DECORATOR_MAP,
            "Supported resource decorators include:"
            f" {', '.join(RESOURCE_DECORATOR_MAP.keys())}",
            decorator,
        )
        decorator_config = RESOURCE_DECORATOR_MAP[key]
        options[key].append(
            parse_options(decorator_config["options"], decorator, visitor=self)
        )
        max_count = decorator_config.get("max_count")
        assert_supported_operation(
            max_count is None or len(options[key]) <= max_count,
            f"Only {max_count} @{key} decorators can be applied to a"
            " state machine function",
            node,
        )

    visitor = StateMachineVisitor(
        node.name, self.task_visitors, self.state_machine_visitors, **options
    )
    visitor.visit(node)
    visitor.shape_nodes()
    self.state_machine_visitors[node.name] = visitor
def visit_Try(self, node: Any) -> None:
    """Visit a try statement.

    Only the ``try``/``except`` form wrapping a single statement is supported.
    The wrapped statement is visited first, then each handler. The exception
    handler nodes themselves are parsed in a different method.

    Examples::

        try:
            MyTask()

    Raises:
        UnsupportedOperation: for multi-statement bodies, ``else``/``finally``
            clauses, or a missing ``except`` handler.
    """
    logging.debug("Visiting a Try")
    constraints = (
        (
            len(node.body) == 1,
            "Only a single task statement at a time can have exception handling applied",
        ),
        (
            len(node.orelse) == 0,
            "The `else` part of a try/except block is not currently supported",
        ),
        (
            len(node.finalbody) == 0,
            "The `finally` part of a try/except block is not currently supported",
        ),
        (len(node.handlers) > 0, "At least 1 exception handler is required"),
    )
    for is_supported, message in constraints:
        assert_supported_operation(is_supported, message, node)
    # The body comes first so the handlers see the task it created as the
    # current state.
    for child in (*node.body, *node.handlers):
        self.visit(child)
def _serialize(self, node: Any) -> Any:
    """Recursive function to serialize a part of the choice branch AST node.

    This looks at each part of the conditional statement's AST to determine the
    correct ASL representation:

    * ``a and b`` / ``a or b``: a dict keyed by ``OP_TO_KEY[op class]`` whose
      value is the list of serialized operands
    * ``not x``: ``{"Not": <serialized x>}``
    * ``data["foo"] > 0``: ``{"Variable": <path>, <comparison key>: <value>}``,
      where ``!=`` is expressed as ``Not`` around an equality check
    * ``data["foo"]``: the reference returned by ``convert_input_data_ref``
    * casts listed in ``TYPE_TO_CLASS`` such as ``int(data["foo"])``: the
      serialized argument, except ``bool(...)``, which becomes an equality
      check against True
    * string, number and True/False literals: their Python value

    Raises:
        UnsupportedOperation: for any construct outside that subset.
    """
    if isinstance(node, ast.BoolOp):
        # e.g. ``___ and ___``
        return {
            OP_TO_KEY[node.op.__class__]: [
                self._serialize(value) for value in node.values
            ]
        }
    elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
        # e.g. ``not bool(data["foo"])``. A ``bool(...)`` call already
        # serializes to a complete equality check, so the operand is wrapped
        # as-is. This branch used to nest that check inside another
        # ``{"Variable": ...}``, which produced an invalid condition.
        return {"Not": self._serialize(node.operand)}
    elif isinstance(node, ast.Compare):
        # e.g. ``data["foo"] > 0``
        assert_supported_operation(
            len(node.ops) == 1,
            "Only 1 comparison operator at a time is allowed",
            node,
        )
        assert_supported_operation(
            len(node.comparators) == 1,
            "Only 1 comparator at a time is allowed",
            node,
        )
        op = node.ops[0]
        comparator = node.comparators[0]

        # Determine the data type of the choice. Callees that are not plain
        # names (e.g. ``data.get(...)``) leave the type undetermined here.
        # Serializing them later raises a proper UnsupportedOperation.
        type_ = None
        if isinstance(node.left, ast.Call) and isinstance(node.left.func, ast.Name):
            type_ = node.left.func.id
        if isinstance(comparator, ast.Call):
            comparator_type = (
                comparator.func.id if isinstance(comparator.func, ast.Name) else None
            )
            assert_supported_operation(
                type_ is None or comparator_type == type_,
                f"Value types must match. Found: {type_} {op.__class__}"
                f" {comparator_type}",
                node,
            )
            if type_ is None:
                type_ = comparator_type
        elif isinstance(comparator, ast.Str):
            type_ = "str"
        elif isinstance(comparator, ast.Num):
            type_ = "float"
        elif isinstance(comparator, ast.NameConstant) and comparator.value in (
            True,
            False,
        ):
            type_ = "bool"
        elif isinstance(comparator, ast.Subscript):
            raise UnsupportedOperation(
                "Input data cannot be used as the comparator (right side of operation)",
                comparator,
            )
        else:
            raise UnsupportedOperation(
                "Could not determine data type for choice variable", comparator
            )

        # Guard the lookup: e.g. ``len(a) == len(b)`` used to escape as a
        # KeyError.
        assert_supported_operation(
            type_ in TYPE_TO_CLASS,
            f"Function {type_} is not supported. Allowed built-ins:"
            f" {', '.join(TYPE_TO_CLASS.keys())}",
            node,
        )
        type_class = TYPE_TO_CLASS[type_]
        if isinstance(op, ast.NotEq):
            return {
                "Not": {
                    "Variable": self._serialize(node.left),
                    OP_TO_KEY[(ast.Eq, type_class)]: self._serialize(comparator),
                }
            }
        return {
            "Variable": self._serialize(node.left),
            OP_TO_KEY[(op.__class__, type_class)]: self._serialize(comparator),
        }
    elif isinstance(node, ast.Subscript):
        # e.g. ``data["foo"]``
        return convert_input_data_ref(node)
    elif isinstance(node, ast.Call):
        # e.g. ``int(data["foo"])``
        allowed_builtins = ", ".join(TYPE_TO_CLASS.keys())
        assert_supported_operation(
            isinstance(node.func, ast.Name),
            "Only built-in type casting functions can be called in a choice"
            f" condition. Allowed built-ins: {allowed_builtins}",
            node,
        )
        # The original message used implicit literal concatenation
        # (``f"..." ", ".join(...)``), which turned the whole message into the
        # join separator.
        assert_supported_operation(
            node.func.id in TYPE_TO_CLASS,
            f"Function {node.func.id} is not supported. Allowed built-ins:"
            f" {allowed_builtins}",
            node,
        )
        assert_supported_operation(
            len(node.args) == 1,
            "Data type casting functions only accept 1 positional argument",
            node,
        )
        if node.func.id == "bool":
            return {
                "Variable": self._serialize(node.args[0]),
                OP_TO_KEY[(ast.Eq, bool)]: True,
            }
        return self._serialize(node.args[0])
    elif isinstance(node, ast.NameConstant):
        assert_supported_operation(
            node.value is not None,
            "The value `None` is not allowed in Choice states",
            node,
        )
        return node.value
    elif isinstance(node, ast.Str):
        return node.s
    elif isinstance(node, ast.Num):
        return node.n
    raise UnsupportedOperation("Unsupported choice branch logic", node)
def visit_If(self, node: Any) -> None:
    """Visit an if statement.

    If statements will have a single if statement, zero or more elif
    statements, and zero or one else statements. In the AST these are
    recursively nested within each node (an ``elif`` is an ``If`` stored as the
    sole element of ``orelse``), but we need to build a flat list of branches on
    a single ChoiceState.

    Examples::

        if data["foo"] > 0:
            MyTaskWhenPositive()
        elif data["foo"] < 0:
            MyTaskWhenNegative()
        else:
            MyTaskWhenZero()

    Raises:
        UnsupportedOperation: if the ``else`` clause holds more than one
            statement.
    """
    logging.debug("Visiting an If")
    # A ChoiceState left on the stack by the previous ``if``/``elif`` visit is
    # reused (this node becomes another branch) unless we are inside a branch
    # body. A nested ``if`` there must start its own ChoiceState.
    current_choice_state = self._pop_choice_state_stack()
    if isinstance(current_choice_state, ChoiceState) and not self._in_choice_body:
        logging.debug(f"Current ChoiceState is {current_choice_state}")
        choice_branch = current_choice_state.add_choice_branch(node)
        self._set_current_state(choice_branch)
    else:
        logging.debug(
            f"Creating new ChoiceState (_in_choice_body={self._in_choice_body})"
        )
        state = ChoiceState(self.state_graph, f"Choice-{hash_node(node)}", node)
        self._add_state(state)
        self._set_current_state(state.current_choice_branch)
        current_choice_state = self._push_choice_state_stack(state)

    # Flag the branch body so that any ``if`` inside it is not mistaken for the
    # next branch of this Choice (see the condition above).
    logging.debug("Setting _in_choice_body=true")
    self._in_choice_body = True
    for item in node.body:
        self.visit(item)
    logging.debug("Setting _in_choice_body=false")
    self._in_choice_body = False

    # ``elif`` (orelse starts with an If) or no else at all: leave the Choice
    # state on the stack for the next If visit to pop.
    # NOTE(review): this method does not visit node.orelse[0] itself; presumably
    # the traversal elsewhere reaches the nested If -- confirm.
    # NOTE(review): the state is also pushed when orelse is empty, so a
    # following sibling ``if`` could be added as a branch of this Choice unless
    # the stack helpers guard against it -- confirm.
    if (len(node.orelse) > 0 and isinstance(node.orelse[0], ast.If)) or len(
            node.orelse) == 0:
        logging.debug("More elif conditions to parse")
        # Append the same state object so we can use it for the next branch
        # of the choice
        current_choice_state = self._push_choice_state_stack(
            current_choice_state)
    else:
        # Plain ``else``: its body is chained from the Choice state itself.
        # ``_in_else`` is presumably read when edges are added so they get the
        # ``in_else`` attribute that marks the Default target -- confirm.
        logging.debug("Parsing final elif choice branch")
        self._set_current_state(current_choice_state)
        logging.debug("Parsing else choice branch")
        self._in_else = True
        assert_supported_operation(
            len(node.orelse) <= 1,
            "A maximum of 1 state can be included in an `else` clause",
            node,
        )
        for item in node.orelse:
            self.visit(item)
        self._in_else = False

    # NOTE(review): block placement was reconstructed from whitespace-mangled
    # source; this check is assumed to run after both paths above -- confirm.
    if isinstance(current_choice_state, ChoiceState):
        # This means we're done building the Choice state because otherwise, we'd
        # be adding things to a ChoiceStateBranch. Point the current state back to
        # the Choice object so that it can point to the Next state.
        self._set_current_state(current_choice_state)
def visit_Expr(self, node: Any) -> None:
    """Visit expression nodes.

    Expressions are function calls that aren't assigned to a variable. These
    include:

    * ``data.update()`` for Pass states
    * ``parallel()`` for Parallel states
    * ``wait`` for Wait states
    * ``map`` for Map states
    * Instantiating Task classes
    * Calling another state machine function (nesting a state machine)

    String expressions (e.g. docstrings) are skipped.

    Examples::

        data.update({"hello": "world"})
        parallel(branch1, branch2)
        wait(seconds=10)
        map(data["items"], item_iterator)
        Foo()

    Raises:
        UnsupportedOperation: for any other expression.
    """
    logging.debug("Visiting an Expr")
    if isinstance(node.value, ast.Str):
        # This is probably a docstring
        return

    unsupported_message = (
        "Supported expressions include:\n"
        "* ``update()`` for Pass states\n"
        "* ``parallel()`` for Parallel states\n"
        "* ``wait`` for Wait states\n"
        "* ``map`` for Map states\n"
        "* Instantiating Task classes\n"
        "* Nesting state machines"
    )
    # Non-call expressions (bare names, arithmetic, constants, ``await`` ...)
    # have no ``.func`` and used to crash with an AttributeError.
    if not isinstance(node.value, ast.Call):
        raise UnsupportedOperation(unsupported_message, node)

    func = node.value.func
    # None for callees that are neither a plain name nor handled as an
    # attribute below (e.g. ``data["fn"]()``). They fall through to the
    # unsupported-expression error.
    func_name = func.id if isinstance(func, ast.Name) else None

    if isinstance(func, ast.Attribute):
        assert_supported_operation(
            isinstance(func.value, ast.Name)
            and func.value.id == "data"
            and func.attr == "update",
            "The only supported method call is `data.update()` to set values on the input data",
            node,
        )
        state = PassState(self.state_graph, f"Pass-{hash_node(node, self.name)}", node)
        self._add_state(state)
        self._set_current_state(state)
    elif func_name == "parallel":
        assert_supported_operation(
            len(node.value.args) > 0,
            "At least one branch function must be provided to the parallel state.",
            node,
        )
        state = ParallelState(self.state_graph, f"Parallel-{hash_node(node)}", node)
        self._add_state(state)
        self._set_current_state(state)
        for arg in node.value.args:
            assert_supported_operation(
                isinstance(arg, ast.Name) and arg.id in self._state_machine_visitors,
                "Only defined functions can be provided to the parallel state."
                f" Available functions: {', '.join(self._other_state_machine_names)}",
                node,
            )
            state.add_branch(self._state_machine_visitors[arg.id])
            # The referenced state machine is used as a parallel branch so we'll
            # demote it
            self._state_machine_visitors[arg.id].is_first_class = False
    elif func_name == "wait":
        state = WaitState(self.state_graph, f"Wait-{hash_node(node)}", node)
        self._add_state(state)
        self._set_current_state(state)
    elif func_name == "map":
        args = node.value.args
        assert_supported_operation(
            len(args) == 2,
            "Map state requires two arguments: a list of items from data and an"
            " iterator function",
            node,
        )
        _, iterator = args
        assert_supported_operation(
            isinstance(iterator, ast.Name)
            and iterator.id in self._state_machine_visitors,
            "Only defined functions can be provided to the map state."
            f" Available functions: {', '.join(self._other_state_machine_names)}",
            node,
        )
        state = MapState(
            self.state_graph,
            f"Map-{hash_node(node)}",
            node,
            self._state_machine_visitors[iterator.id],
        )
        self._add_state(state)
        self._set_current_state(state)
        # The referenced state machine is used as an iterator so we'll demote it
        self._state_machine_visitors[iterator.id].is_first_class = False
        self._state_machine_visitors[iterator.id].is_map_iterator = True
    elif func_name in self._task_visitors:
        # This node is instantiating a task class
        state = create_task_state(self, node, self._task_visitors[func_name])
        self._add_state(state)
        self._set_current_state(state)
    elif func_name in self._state_machine_visitors:
        # This node is nesting a state machine
        state = create_task_state(self, node, self._state_machine_visitors[func_name])
        self._add_state(state)
        self._set_current_state(state)
    else:
        raise UnsupportedOperation(unsupported_message, node)
def visit_Assign(self, node: Any) -> None:
    """Visit assignment nodes.

    Assignments are used to create Task, Map and Pass states. The target must
    be a single key on ``data``. The assigned value decides the state type:

    * a call to a known task class or state machine function -> task state
    * ``map(items, iterator)`` -> Map state (the iterator function is demoted
      from a standalone state machine)
    * anything else -> Pass state (static data)

    Examples::

        data["result"] = Foo()
        data["result"] = map(data["items"], item_iterator)
        data["result"] = {"hello": "world"}

    Raises:
        UnsupportedOperation: for multiple targets, a target that is not a key
            on ``data``, or a malformed ``map()`` call.
    """
    assert_supported_operation(
        len(node.targets) == 1,
        "Value assignments can only target one variable",
        node,
    )
    source = astor.to_source(node).strip()
    logging.debug(f"Visiting Assign ({source})")
    target = node.targets[0]
    # Check the subscript base is a Name before reading ``.id``. For example
    # ``self.x["k"] = ...`` used to crash with an AttributeError.
    assert_supported_operation(
        isinstance(target, ast.Subscript)
        and isinstance(target.value, ast.Name)
        and target.value.id == "data",
        "Assignment target must be a key on `data`",
        node,
    )

    # Name of the called function, or None when the value is not a call to a
    # plain name. Attribute callees used to crash with an AttributeError. They
    # now take the same static-data path as any other unrecognized call.
    func_name = (
        node.value.func.id
        if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name)
        else None
    )

    if func_name is not None and func_name in self._task_visitors:
        # This node is instantiating a task class
        state = create_task_state(self, node, self._task_visitors[func_name])
        self._add_state(state)
        self._set_current_state(state)
    elif func_name is not None and func_name in self._state_machine_visitors:
        # This node is nesting a state machine
        state = create_task_state(self, node, self._state_machine_visitors[func_name])
        self._add_state(state)
        self._set_current_state(state)
    elif func_name == "map":
        # This node is instantiating a map state
        args = node.value.args
        assert_supported_operation(
            len(args) == 2,
            "Map state requires two arguments: a list of items from data and an"
            " iterator function",
            node,
        )
        _, iterator = args
        assert_supported_operation(
            isinstance(iterator, ast.Name)
            and iterator.id in self._state_machine_visitors,
            "Only defined functions can be provided to the map state."
            f" Available functions: {', '.join(self._other_state_machine_names)}",
            node,
        )
        state = MapState(
            self.state_graph,
            f"Map-{hash_node(node)}",
            node,
            self._state_machine_visitors[iterator.id],
        )
        self._add_state(state)
        self._set_current_state(state)
        # The referenced state machine is used as an iterator so we'll demote it
        self._state_machine_visitors[iterator.id].is_first_class = False
        self._state_machine_visitors[iterator.id].is_map_iterator = True
    else:
        # This node is setting static data
        state = PassState(self.state_graph, f"Pass-{hash_node(node, self.name)}", node)
        self._add_state(state)
        self._set_current_state(state)