def visit_With(self, node: Any) -> None:
        """Visit a with statement.

        Only the ``retry()`` context manager is supported. It must be the only
        context manager in the statement and must wrap exactly one statement. After
        that statement is visited, the ``retry(...)`` call node is passed to the
        resulting current state's ``add_retries()``.

        Examples::

            with retry(
                on_exceptions=["CustomError1", "CustomError2"],
                interval=10
            ):
                TaskToRetry()

        Raises:
            UnsupportedOperation: if more than one context manager is used, the
                context manager is not a plain ``retry(...)`` call, or the body is
                not a single statement.

        """
        # ``with retry(), other():`` would otherwise silently drop ``other()``.
        assert_supported_operation(
            len(node.items) == 1,
            "Only one context manager can be used per with statement",
            node,
        )
        context_manager = node.items[0].context_expr
        # ``with lock:`` (not a call) and ``with mod.retry():`` (attribute call)
        # have no ``func.id``; check the shape first so they are reported as
        # unsupported instead of crashing with an AttributeError.
        func = getattr(context_manager, "func", None)
        is_retry = (
            isinstance(context_manager, ast.Call)
            and isinstance(func, ast.Name)
            and func.id == "retry"
        )
        if not is_retry:
            raise UnsupportedOperation(
                """Supported context managers include:
* ``retry()`` for retrying tasks
""",
                node,
            )

        assert_supported_operation(
            len(node.body) == 1,
            "The retry context manager can only wrap a single task",
            node,
        )
        self.visit(node.body[0])
        # Visiting the wrapped statement leaves it as the current state.
        task = self._current_state
        task.add_retries(context_manager)
    def visit_ExceptHandler(self, node: Any) -> None:
        """Visit an exception handler (a single ``except`` clause).

        The surrounding ``try`` statement is parsed by a different method. When
        this runs, the current state must be a TaskState. A catch is registered
        on it for this handler, the handler body is visited with that catch as
        the current state, and the task is then restored as the current state.

        Examples::

            except BadThing:
                TaskToHandleBadThing()
            except States.Timeout:
                TaskToPingSlack()

        Raises:
            UnsupportedOperation: if the current state is not a TaskState.

        """
        logging.debug("Visiting an ExceptHandler")
        owner = self._current_state
        assert_supported_operation(
            isinstance(owner, TaskState),
            "Only task states can have exception handlers",
            node,
        )
        self._set_current_state(owner.add_catch(node))
        for statement in node.body:
            self.visit(statement)

        # Point back at the owning task so the next ``except`` clause attaches
        # its catch to the same task.
        self._set_current_state(owner)
 def to_dict(self) -> Dict:
     """Serialize this ChoiceBranch into a dict.

     The transition fields written by ``_set_end_or_next`` are merged with the
     serialized form of the branch's ``test`` expression.

     Raises:
         UnsupportedOperation: if the test does not serialize to a dict, which
             usually means it contains no comparison or boolean logic.
     """
     result = {}
     self._set_end_or_next(result)
     test_node = self.ast_node.test
     condition = self._serialize(test_node)
     assert_supported_operation(
         isinstance(condition, dict),
         "Invalid conditional statement. Most likely there is no comparison or"
         " boolean logic present.",
         test_node,
     )
     result.update(condition)
     return result
Beispiel #4
0
    def visit_AsyncFunctionDef(self, node: Any) -> None:
        """Validate an async method defined on a custom event processor.

        Only methods named ``get_custom_tags_*`` are allowed. The names of their
        positional parameters must be exactly ``message``, ``input_data`` and
        ``state_data_client``. The names are compared as a set, so order is not
        checked. The method body is not visited.

        Raises:
            UnsupportedOperation: if the method name or parameter names are wrong.

        """
        method_name = node.name
        assert_supported_operation(
            method_name.startswith("get_custom_tags_"),
            f"Custom event processors can only implement get_custom_tags_* methods. Provided: {method_name}",
            node,
        )
        positional_names = [argument.arg for argument in node.args.args]
        expected_names = {"message", "input_data", "state_data_client"}
        assert_supported_operation(
            set(positional_names) == expected_names,
            "get_custom_tags_* methods must only accept positional arguments of"
            " (message, input_data, state_data_client)."
            f" Provided: {', '.join(positional_names)}",
            node,
        )
    def _parse_result_path(self) -> str:
        """Work out where the task's result is inserted into the data object.

        For an assignment such as ``data["foo"] = Task()``, the first assignment
        target is converted to a path with ``convert_input_data_ref``. Any other
        node yields the root path ``"$"``.

        Returns:
            result path string

        Raises:
            UnsupportedOperation: if the path matches ``INVALID_RESULT_PATH_PATTERN``
                (the error message attributes this to reserved input-data keys).

        """
        if not isinstance(self.ast_node, ast.Assign):
            return "$"

        path = convert_input_data_ref(self.ast_node.targets[0])
        assert_supported_operation(
            re.search(INVALID_RESULT_PATH_PATTERN, path) is None,
            "Task result path is invalid. Check that it does not contain reserved"
            f" keys: {', '.join(RESERVED_INPUT_DATA_KEYS)}",
            self.ast_node,
        )
        return path
Beispiel #6
0
    def visit_ClassDef(self, node) -> None:
        """Visit the class definition.

        Class-level assignments of the form ``name = value`` whose name is a key
        of ``ATTRIBUTE_MAP`` are parsed with that attribute's schema. The result
        is stored in ``self.attributes``. The class body is then visited
        generically, which recurses into the class methods.

        Raises:
            UnsupportedOperation: if a parsed attribute value is not one of the
                attribute's ``allowed_values``.

        """
        for item in node.body:
            if not isinstance(item, ast.Assign):
                continue
            target = item.targets[0]
            # Tuple unpacking (``a, b = ...``) and attribute/subscript targets have
            # no ``.id``. They can never be a class attribute, so skip them instead
            # of crashing with an AttributeError.
            if not isinstance(target, ast.Name) or target.id not in ATTRIBUTE_MAP:
                continue
            key = target.id
            attribute = ATTRIBUTE_MAP[key]
            value = attribute.get_value(item.value, visitor=self)
            if attribute.allowed_values is not None:
                allowed = ", ".join(
                    str(allowed_value) for allowed_value in attribute.allowed_values
                )
                # Report against the offending assignment, not the whole class.
                assert_supported_operation(
                    value in attribute.allowed_values,
                    f"Allowed values for class attribute {key} include: {allowed}",
                    item,
                )
            self.attributes[key] = value

        self.generic_visit(node)
 def shape(self) -> None:
     """Shape the graph around this Choice state node.

     Neighbours reached through an edge whose ``in_else`` attribute is falsy
     represent the state that comes *after* the whole Choice. At most one is
     allowed. When there is one, each choice branch whose last descendant is not
     terminal gets an edge from that descendant to the following state.

     Raises:
         UnsupportedOperation: if more than one state follows the Choice state.
     """
     followers = []
     for neighbour, attrs in self.state_graph.adj[self].items():
         if not attrs.get("in_else"):
             followers.append(neighbour)
     assert_supported_operation(
         len(followers) <= 1,
         "A maximum of 1 state can be downstream from a Choice state",
         self.ast_node,
     )
     if len(followers) != 1:
         return

     next_state = followers[0]
     for branch in self.choice_branches:
         chain = branch.descendants
         # Only branches that can fall through need a link to the next state.
         if len(chain) > 0 and not chain[-1].TERMINAL:
             self.state_graph.add_edge(chain[-1], next_state)
Beispiel #8
0
    def _get_result_path(self) -> Optional[str]:
        """Return the ResultPath for the task state, or None.

        A result path exists only when the task is written as an assignment
        (``data["foo"] = Task()``). The first target is converted with
        ``convert_input_data_ref``. Otherwise None is returned to indicate that
        the result should be discarded.

        Returns:
            result path string or None

        Raises:
            UnsupportedOperation: if the path matches ``INVALID_RESULT_PATH_PATTERN``.

        """
        assign = self.ast_node
        if not (isinstance(assign, ast.Assign) and len(assign.targets) > 0):
            return None

        path = convert_input_data_ref(assign.targets[0])
        assert_supported_operation(
            re.search(INVALID_RESULT_PATH_PATTERN, path) is None,
            "Task result path is invalid. Check that it does not contain reserved"
            f" keys: {', '.join(RESERVED_INPUT_DATA_KEYS)}",
            assign,
        )
        return path
    def to_dict(self):
        """Serialize this Choice state into a dict.

        ``Choices`` holds each branch's ``to_dict()`` output. The neighbour reached
        through an edge flagged ``in_else`` (the state taken when no branch
        matches) becomes ``Default``. At most one such neighbour is allowed.

        Raises:
            UnsupportedOperation: if more than one state is in the ``else`` clause.
        """
        serialized = {"Type": "Choice"}
        serialized["Choices"] = [
            branch.to_dict() for branch in self.choice_branches
        ]

        default_targets = []
        for neighbour, attrs in self.state_graph.adj[self].items():
            if attrs.get("in_else"):
                default_targets.append(neighbour)
        assert_supported_operation(
            len(default_targets) <= 1,
            "A maximum of 1 state can be included in an `else` clause",
            self.ast_node,
        )
        if len(default_targets) == 1:
            serialized["Default"] = default_targets[0].key

        return serialized
Beispiel #10
0
def parse_options(
    option_map: Dict[str, CallableOption],
    node: ast.Call,
    visitor: Optional[ast.NodeVisitor] = None,
) -> OptionsMap:
    """Parse options from keyword arguments passed to a Callable.

    See :py:class:`CallableOption`

    Every option in ``option_map`` is first set to its default value. A callable
    ``default_value`` is called as ``default_value(node, visitor)``. Each keyword
    argument on ``node`` then overrides its option with ``option.get_value(...)``.

    Args:
        option_map: Map of keyword argument name to the CallableOption schema
        node: AST node of the task state being added to the state machine
        visitor: Instance of a node visitor to provide extra context for parsing options

    Returns:
        OptionsMap instance, which is a dict of key-values with defaults filled in

    Raises:
        UnsupportedOperation: if a keyword is not in ``option_map``, a keyword's AST
            value is not an instance of the option's ``value_type``, or a required
            option was not provided.

    """
    options = OptionsMap(node)
    options.update({
        key: (
            option.default_value(node, visitor)
            if callable(option.default_value)
            else option.default_value
        )
        for key, option in option_map.items()
    })
    for keyword in node.keywords:
        key = keyword.arg
        assert_supported_operation(
            key in option_map,
            f"Invalid keyword argument. Options: {', '.join(option_map.keys())}",
            node,
        )
        option = option_map[key]
        if option.value_type:
            # The type check is applied to the argument's raw AST node.
            assert_supported_operation(
                isinstance(keyword.value, option.value_type),
                f"Invalid data type for the {key} option:"
                f" expected a {option.value_type_label}.",
                node,
            )
        options[key] = option.get_value(keyword.value, visitor)

    # Ensure that all required options were provided. The missing names are
    # sorted so the error message is stable: iteration order of a set of
    # strings changes between interpreter runs because of hash randomization.
    required_options = {
        key for key, option in option_map.items() if option.required
    }
    provided_options = {keyword.arg for keyword in node.keywords}
    missing_options = sorted(required_options - provided_options)
    assert_supported_operation(
        len(missing_options) == 0,
        f"The following options are required but were not provided: {', '.join(missing_options)}",
        node,
    )

    return options
Beispiel #11
0
    def visit_FunctionDef(self, node: Any) -> None:
        """Visit a function definition.

        Every module-level function is, to start, considered a full state machine.
        We'll determine later on whether the state machine is an embedded parallel
        branch, map iterator, or standalone state machine.

        Each decorator must be a call to a resource decorator listed in
        ``RESOURCE_DECORATOR_MAP``. Its keyword arguments are parsed with
        ``parse_options``, and the parsed options are grouped by decorator name.
        They are then passed as keyword arguments to a new ``StateMachineVisitor``.
        That visitor is run over the function, shaped, and registered under the
        function's name.

        Raises:
            UnsupportedOperation: if a decorator is not a call to a supported
                resource decorator, or a decorator is used more often than its
                ``max_count`` allows.
        """
        # Parse state machine options from the list of function decorators
        options = defaultdict(list)
        for decorator in node.decorator_list:
            # Check the decorator's shape *before* reading ``func.id``. A bare
            # ``@name`` or an attribute call ``@module.name(...)`` has no
            # ``func.id`` and would raise AttributeError instead of the intended
            # UnsupportedOperation.
            func = getattr(decorator, "func", None)
            key = func.id if isinstance(func, ast.Name) else None
            assert_supported_operation(
                isinstance(decorator, ast.Call) and key in RESOURCE_DECORATOR_MAP,
                "Supported resource decorators include:"
                f" {', '.join(RESOURCE_DECORATOR_MAP.keys())}",
                decorator,
            )
            decorator_config = RESOURCE_DECORATOR_MAP[key]
            options[key].append(
                parse_options(decorator_config["options"], decorator, visitor=self)
            )
            max_count = decorator_config.get("max_count")
            assert_supported_operation(
                max_count is None or len(options[key]) <= max_count,
                f"Only {max_count} @{key} decorators can be applied to a"
                " state machine function",
                node,
            )

        visitor = StateMachineVisitor(
            node.name, self.task_visitors, self.state_machine_visitors, **options
        )
        visitor.visit(node)
        visitor.shape_nodes()
        self.state_machine_visitors[node.name] = visitor
    def visit_Try(self, node: Any) -> None:
        """Visit a try statement.

        The ``try`` body must hold exactly one statement. At least one ``except``
        handler is required, and ``else``/``finally`` clauses are rejected. The
        body is visited first, then every handler. The handlers are dispatched
        through ``self.visit``, so they are parsed by the ExceptHandler visitor
        method.

        Examples::

            try:
                MyTask()

        Raises:
            UnsupportedOperation: if any of the requirements above is violated
                (checked in the order listed in the code).

        """
        logging.debug("Visiting a Try")
        requirements = (
            (len(node.body) == 1,
             "Only a single task statement at a time can have exception handling applied"),
            (len(node.orelse) == 0,
             "The `else` part of a try/except block is not currently supported"),
            (len(node.finalbody) == 0,
             "The `finally` part of a try/except block is not currently supported"),
            (len(node.handlers) > 0,
             "At least 1 exception handler is required"),
        )
        for satisfied, message in requirements:
            assert_supported_operation(satisfied, message, node)

        for child in [*node.body, *node.handlers]:
            self.visit(child)
    def _serialize(self, node: Any) -> Any:
        """Recursive function to serialize a part of the choice branch AST node.

        This looks at each part of the conditional statement's AST to determine the
        correct ASL representation:

        * ``a and b`` / ``a or b``: ``{OP_TO_KEY[op]: [<a>, <b>, ...]}``
        * ``not x``: ``{"Not": <x>}``
        * a single comparison such as ``data["foo"] > 0``:
          ``{"Variable": <left>, <key>: <right>}``. The key is looked up in
          ``OP_TO_KEY`` by (operator class, value type class). ``!=`` becomes a
          ``Not`` around the equality check. A type cast on either side, e.g.
          ``int(data["foo"])``, is unwrapped to its argument.
        * ``data[...]``: the path produced by ``convert_input_data_ref``
        * a standalone cast: its serialized argument. ``bool(...)`` instead
          becomes an equality check against ``True``.
        * string, number and boolean literals: their Python value

        Args:
            node: the AST node to serialize

        Returns:
            a dict for conditions, a path string for ``data[...]`` references, or a
            literal value

        Raises:
            UnsupportedOperation: for constructs not listed above, mismatched or
                unknown value types, ``None`` literals, or input data used as the
                comparator.
        """
        if isinstance(node, ast.BoolOp):
            # e.g. ``___ and ___``
            return {
                OP_TO_KEY[node.op.__class__]: [
                    self._serialize(value) for value in node.values
                ]
            }
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
            # e.g. ``not bool(data["foo"])``. The Call branch below already turns
            # ``bool(...)`` into a complete equality check. Wrapping it again here
            # would nest a condition dict inside "Variable", so serialize the
            # operand directly.
            return {"Not": self._serialize(node.operand)}
        elif isinstance(node, ast.Compare):
            # e.g. ``data["foo"] > 0``
            assert_supported_operation(
                len(node.ops) == 1,
                "Only 1 comparison operator at a time is allowed",
                node,
            )
            assert_supported_operation(
                len(node.comparators) == 1,
                "Only 1 comparator at a time is allowed",
                node,
            )
            op = node.ops[0]
            comparator = node.comparators[0]
            # Determine the data type of the choice
            type_ = None
            if isinstance(node.left, ast.Call):
                type_ = self._cast_function_name(node.left)
            if isinstance(comparator, ast.Call):
                comparator_type = self._cast_function_name(comparator)
                assert_supported_operation(
                    type_ is None or comparator_type == type_,
                    f"Value types must match. Found: {type_} {op.__class__}"
                    f" {comparator_type}",
                    node,
                )
                if type_ is None:
                    type_ = comparator_type
            elif isinstance(comparator, ast.Str):
                type_ = "str"
            elif isinstance(comparator, ast.Num):
                type_ = "float"
            elif isinstance(comparator, ast.NameConstant) and comparator.value in (
                True,
                False,
            ):
                type_ = "bool"
            elif isinstance(comparator, ast.Subscript):
                raise UnsupportedOperation(
                    "Input data cannot be used as the comparator (right side of operation)",
                    comparator,
                )
            else:
                raise UnsupportedOperation(
                    "Could not determine data type for choice variable", comparator
                )

            # An unknown cast name (e.g. ``foo(data["x"]) == foo(1)``) would
            # otherwise surface as a bare KeyError.
            assert_supported_operation(
                type_ in TYPE_TO_CLASS,
                f"Function {type_} is not supported. Allowed built-ins: "
                f"{', '.join(TYPE_TO_CLASS.keys())}",
                node,
            )
            type_class = TYPE_TO_CLASS[type_]
            # Casts are unwrapped so "Variable" is always a path, never a
            # nested condition (e.g. ``bool(data["foo"]) == True``).
            variable = self._serialize_comparison_side(node.left)
            value = self._serialize_comparison_side(comparator)
            if isinstance(op, ast.NotEq):
                return {
                    "Not": {
                        "Variable": variable,
                        OP_TO_KEY[(ast.Eq, type_class)]: value,
                    }
                }

            return {
                "Variable": variable,
                OP_TO_KEY[(op.__class__, type_class)]: value,
            }

        elif isinstance(node, ast.Subscript):
            # e.g. ``data["foo"]``
            return convert_input_data_ref(node)
        elif isinstance(node, ast.Call):
            # e.g. ``int(data["foo"])``
            argument = self._serialize_comparison_side(node)
            if node.func.id == "bool":
                # A standalone ``bool(...)`` is a full condition: "is True".
                return {
                    "Variable": argument,
                    OP_TO_KEY[(ast.Eq, bool)]: True,
                }
            return argument
        elif isinstance(node, ast.NameConstant):
            assert_supported_operation(
                node.value is not None,
                "The value `None` is not allowed in Choice states",
                node,
            )
            return node.value
        elif isinstance(node, ast.Str):
            return node.s
        elif isinstance(node, ast.Num):
            return node.n

        raise UnsupportedOperation("Unsupported choice branch logic", node)

    @staticmethod
    def _cast_function_name(call: Any) -> Any:
        """Return the name of a called function, or None if it is not a plain name."""
        # ``func`` may be an attribute (``str.lower(...)``) without an ``.id``.
        return call.func.id if isinstance(call.func, ast.Name) else None

    def _serialize_comparison_side(self, node: Any) -> Any:
        """Serialize one operand of a comparison, unwrapping a type-cast call.

        Raises:
            UnsupportedOperation: if a call is not a single-argument call to one
                of the ``TYPE_TO_CLASS`` built-ins.
        """
        if not isinstance(node, ast.Call):
            return self._serialize(node)
        func_name = self._cast_function_name(node)
        assert_supported_operation(
            func_name in TYPE_TO_CLASS,
            f"Function {func_name} is not supported. Allowed built-ins: "
            f"{', '.join(TYPE_TO_CLASS.keys())}",
            node,
        )
        assert_supported_operation(
            len(node.args) == 1,
            "Data type casting functions only accept 1 positional argument",
            node,
        )
        return self._serialize(node.args[0])
    def visit_If(self, node: Any) -> None:
        """Visit an if statement

        If statements will have a single if statement, zero or more elif statements,
        and zero or one else statements. In the AST these are recursively nested within
        each node but we need to build a flat list.

        A choice-state stack carries the ChoiceState under construction from one
        node to the next. Each ``elif`` arrives here as an ``If`` node inside the
        previous node's ``orelse``. It pops the ChoiceState pushed while handling
        that previous node and adds a new branch to it instead of creating a new
        Choice state.

        Examples::

            if data["foo"] > 0:
                MyTaskWhenPositive()
            elif data["foo"] < 0:
                MyTaskWhenNegative()
            else:
                MyTaskWhenZero()

        Raises:
            UnsupportedOperation: if the ``else`` clause holds more than one
                statement.

        """
        logging.debug("Visiting an If")
        # Whatever the previous If left on the stack. Presumably this is not a
        # ChoiceState when no Choice is in progress — TODO confirm what
        # _pop_choice_state_stack returns for an empty stack.
        current_choice_state = self._pop_choice_state_stack()

        # A Choice is in progress and this If is not nested inside a branch body:
        # treat it as the next ``elif`` branch of that Choice.
        if isinstance(current_choice_state,
                      ChoiceState) and not self._in_choice_body:
            logging.debug(f"Current ChoiceState is {current_choice_state}")
            choice_branch = current_choice_state.add_choice_branch(node)
            self._set_current_state(choice_branch)
        else:
            # Either no Choice is in progress or this If sits inside a branch
            # body: start a brand-new Choice state and push it on the stack.
            logging.debug(
                f"Creating new ChoiceState (_in_choice_body={self._in_choice_body})"
            )
            state = ChoiceState(self.state_graph, f"Choice-{hash_node(node)}",
                                node)
            self._add_state(state)
            self._set_current_state(state.current_choice_branch)
            current_choice_state = self._push_choice_state_stack(state)

        # While the branch body is visited, a nested If must not be mistaken for
        # an elif of this Choice (see the condition above).
        # NOTE(review): the flag is reset to False rather than restored to its
        # previous value. After a nested If finishes, the enclosing body is no
        # longer marked as "in choice body" — confirm this is intended.
        logging.debug("Setting _in_choice_body=true")
        self._in_choice_body = True
        for item in node.body:
            self.visit(item)
        logging.debug("Setting _in_choice_body=false")
        self._in_choice_body = False

        # ``orelse`` starts with an If (an elif follows) or is empty (no more
        # clauses): re-push the Choice so the next visit_If can pop and extend it.
        # NOTE(review): when ``orelse`` is empty the Choice is still re-pushed;
        # confirm a later, unrelated ``if`` at the same level is not then treated
        # as an elif of this Choice.
        if (len(node.orelse) > 0
                and isinstance(node.orelse[0], ast.If)) or len(
                    node.orelse) == 0:
            logging.debug("More elif conditions to parse")
            # Append the same state object so we can use it for the next branch
            # of the choice
            current_choice_state = self._push_choice_state_stack(
                current_choice_state)
        else:
            # A plain ``else`` body follows: make the Choice itself the current
            # state so the else-clause statement attaches to it.
            logging.debug("Parsing final elif choice branch")
            self._set_current_state(current_choice_state)

        # States added while ``_in_else`` is set are presumably linked with an
        # ``in_else`` edge (Choice.to_dict reads that flag to find Default) —
        # confirm in _add_state. A nested elif If is also visited here.
        logging.debug("Parsing else choice branch")
        self._in_else = True
        assert_supported_operation(
            len(node.orelse) <= 1,
            "A maximum of 1 state can be included in an `else` clause",
            node,
        )
        for item in node.orelse:
            self.visit(item)
        self._in_else = False

        if isinstance(current_choice_state, ChoiceState):
            # This means we're done building the Choice state because otherwise, we'd
            # be adding things to a ChoiceStateBranch. Point the current state back to
            # the Choice object so that it can point to the Next state.
            self._set_current_state(current_choice_state)
    def visit_Expr(self, node: Any) -> None:
        """Visit expression nodes.

        Expressions are function calls that aren't assigned to a variable. These include:
        * ``update()`` for Pass states
        * ``parallel()`` for Parallel states
        * ``wait`` for Wait states
        * ``map`` for Map states
        * Instantiating Task classes

        Examples::

            data.update({"hello": "world"})
            parallel(branch1, branch2)
            wait(seconds=10)
            map(data["items"], item_iterator)
            Foo()

        """
        logging.debug("Visiting an Expr")
        if isinstance(node.value, ast.Str):
            # This is probably a docstring
            return

        if isinstance(node.value.func, ast.Attribute):
            assert_supported_operation(
                node.value.func.value.id == "data"
                and node.value.func.attr == "update",
                "The only supported method call is `data.update()` to set values on the input data",
                node,
            )
            state = PassState(self.state_graph,
                              f"Pass-{hash_node(node, self.name)}", node)
            self._add_state(state)
            self._set_current_state(state)
        elif node.value.func.id == "parallel":
            assert_supported_operation(
                len(node.value.args) > 0,
                "At least one branch function must be provided to the parallel state.",
                node,
            )
            state = ParallelState(self.state_graph,
                                  f"Parallel-{hash_node(node)}", node)
            self._add_state(state)
            self._set_current_state(state)
            for arg in node.value.args:
                assert_supported_operation(
                    isinstance(arg, ast.Name)
                    and arg.id in self._state_machine_visitors,
                    "Only defined functions can be provided to the parallel state."
                    f" Available functions: {', '.join(self._other_state_machine_names)}",
                    node,
                )
                state.add_branch(self._state_machine_visitors[arg.id])
                # The referenced state machine is used as a parallel branch so we'll
                # demote it
                self._state_machine_visitors[arg.id].is_first_class = False
        elif node.value.func.id == "wait":
            state = WaitState(self.state_graph, f"Wait-{hash_node(node)}",
                              node)
            self._add_state(state)
            self._set_current_state(state)
        elif node.value.func.id == "map":
            args = node.value.args
            assert_supported_operation(
                len(args) == 2,
                "Map state requires two arguments: a list of items from data and an"
                " iterator function",
                node,
            )
            _, iterator = args
            assert_supported_operation(
                isinstance(iterator, ast.Name)
                and iterator.id in self._state_machine_visitors,
                "Only defined functions can be provided to the map state."
                f" Available functions: {', '.join(self._other_state_machine_names)}",
                node,
            )
            state = MapState(
                self.state_graph,
                f"Map-{hash_node(node)}",
                node,
                self._state_machine_visitors[iterator.id],
            )
            self._add_state(state)
            self._set_current_state(state)
            # The referenced state machine is used as an iterator so we'll demote it
            self._state_machine_visitors[iterator.id].is_first_class = False
            self._state_machine_visitors[iterator.id].is_map_iterator = True
        elif node.value.func.id in self._task_visitors:
            # This node is instantiating a task class
            state = create_task_state(self, node,
                                      self._task_visitors[node.value.func.id])
            self._add_state(state)
            self._set_current_state(state)
        elif node.value.func.id in self._state_machine_visitors:
            # This node is nesting a state machine
            state = create_task_state(
                self, node, self._state_machine_visitors[node.value.func.id])
            self._add_state(state)
            self._set_current_state(state)
        else:
            raise UnsupportedOperation(
                """Supported expressions include:
* ``update()`` for Pass states
* ``parallel()`` for Parallel states
* ``wait`` for Wait states
* Instantiating Task classes
* Nesting state machines""",
                node,
            )
    def visit_Assign(self, node: Any) -> None:
        """Visit assignment nodes.

        Assignments are used to create Task, Map and Pass states:

        * a call to a known task class or state machine function creates a task
          state
        * ``map(items, iterator)`` creates a Map state and demotes the iterator
          state machine
        * anything else is treated as static data and becomes a Pass state

        Examples::

            data["result"] = Foo()
            data["result"] = map(data["items"], item_iterator)
            data["result"] = {"hello": "world"}

        Raises:
            UnsupportedOperation: if there is more than one target, the target is
                not ``data[...]``, or a ``map()`` call is malformed.

        """
        assert_supported_operation(
            len(node.targets) == 1,
            "Value assignments can only target one variable",
            node,
        )
        source = astor.to_source(node).strip()
        logging.debug(f"Visiting Assign ({source})")
        target = node.targets[0]
        # Check that the subscripted object is a plain name before reading
        # ``.id``. For ``data["a"]["b"]`` or ``obj.attr["k"]``, ``target.value``
        # has no ``.id`` and would raise AttributeError instead of this error.
        assert_supported_operation(
            isinstance(target, ast.Subscript)
            and isinstance(target.value, ast.Name)
            and target.value.id == "data",
            "Assignment target must be a key on `data`",
            node,
        )
        value = node.value
        func_name = value.func.id if isinstance(value, ast.Call) else None

        if func_name in self._task_visitors:
            # This node is instantiating a task class
            state = create_task_state(self, node,
                                      self._task_visitors[func_name])
            self._add_state(state)
            self._set_current_state(state)
        elif func_name in self._state_machine_visitors:
            # This node is nesting a state machine
            state = create_task_state(
                self, node, self._state_machine_visitors[func_name])
            self._add_state(state)
            self._set_current_state(state)
        elif func_name == "map":
            # This node is instantiating a map state
            args = value.args
            assert_supported_operation(
                len(args) == 2,
                "Map state requires two arguments: a list of items from data and an"
                " iterator function",
                node,
            )
            _, iterator = args
            assert_supported_operation(
                isinstance(iterator, ast.Name)
                and iterator.id in self._state_machine_visitors,
                "Only defined functions can be provided to the map state."
                f" Available functions: {', '.join(self._other_state_machine_names)}",
                node,
            )
            state = MapState(
                self.state_graph,
                f"Map-{hash_node(node)}",
                node,
                self._state_machine_visitors[iterator.id],
            )
            self._add_state(state)
            self._set_current_state(state)
            # The referenced state machine is used as an iterator so we'll demote it
            self._state_machine_visitors[iterator.id].is_first_class = False
            self._state_machine_visitors[iterator.id].is_map_iterator = True
        else:
            # This node is setting static data
            state = PassState(self.state_graph,
                              f"Pass-{hash_node(node, self.name)}", node)
            self._add_state(state)
            self._set_current_state(state)