Ejemplo n.º 1
0
 def __init__(self, *args, **kwargs):
     """Build the exception message by joining ``args`` with newlines.

     Parameters
     ----------
     *args
         Message fragments. Each one is converted with ``str`` and the
         results are joined with ``"\\n"`` so multi-part messages print on
         separate lines.
     variable : optional, keyword-only
         When supplied, ``get_variable_trace_string(variable)`` is called and
         its result, if non-empty, is appended as the final fragment.

     Raises
     ------
     TypeError
         If any keyword argument other than ``variable`` is supplied.
     """
     if kwargs:
         # Validate explicitly instead of with `assert`: asserts are removed
         # under `python -O`, which would let unexpected keywords through.
         unexpected = set(kwargs) - {"variable"}
         if unexpected:
             raise TypeError(
                 f"{type(self).__name__}() got unexpected keyword argument(s): "
                 f"{', '.join(sorted(unexpected))}"
             )
         error_msg = get_variable_trace_string(kwargs["variable"])
         if error_msg:
             args = args + (error_msg,)
     # Stringify each fragment so non-str arguments do not break the join;
     # real newlines are needed for the message to print correctly.
     s = "\n".join(str(arg) for arg in args)
     super().__init__(s)
Ejemplo n.º 2
0
def compute_test_value(node: Apply):
    """Computes the test value of a node.

    The test value of every input is obtained with ``get_test_value()``; the
    node's `Op` is then executed on those values and the results are stored
    on the outputs. If some input has no test value (``TestValueError``), the
    behavior depends on ``config.compute_test_value``:

    - ``"warn"``: emit a warning and return without computing anything.
    - ``"raise"``: raise a ``ValueError`` that includes the input's trace.
    - ``"ignore"``: return silently without computing anything.
    - ``"pdb"``: open a post-mortem debugger on the ``TestValueError``, then
      return without computing anything.
    - any other value: raise a ``ValueError`` naming the invalid option.

    Parameters
    ----------
    node : Apply
        The `Apply` node for which the test value is computed.

    Returns
    -------
    None
        The `tag.test_value`s are updated in each `Variable` in `node.outputs`
        (only when every input has a test value).

    Raises
    ------
    ValueError
        When an input lacks a test value and ``config.compute_test_value`` is
        ``"raise"`` or an unrecognized option.

    """
    # Gather the test values for each input of the node
    storage_map = {}
    compute_map = {}
    for i, ins in enumerate(node.inputs):
        try:
            storage_map[ins] = [ins.get_test_value()]
            compute_map[ins] = [True]
        except TestValueError:
            # no test-value was specified, act according to the config option
            mode = config.compute_test_value
            if mode == "warn":
                warnings.warn(
                    f"Warning, Cannot compute test value: input {i} ({ins}) of Op {node} missing default value",
                    stacklevel=2,
                )
                return
            elif mode == "raise":
                detailed_err_msg = get_variable_trace_string(ins)

                raise ValueError(
                    f"Cannot compute test value: input {i} ({ins}) of Op {node} missing default value. {detailed_err_msg}"
                )
            elif mode == "ignore":
                return
            elif mode == "pdb":
                import pdb

                pdb.post_mortem(sys.exc_info()[2])
                # This input still has no test value, so the computation
                # below cannot run (it would fail looking up `storage_map`).
                # Stop once the debugging session ends.
                return
            else:
                raise ValueError(
                    f"{config.compute_test_value} is invalid for option config.compute_test_value"
                )

    # All inputs have test-values; perform the `Op`'s computation

    # The original values should not be destroyed, so we copy the values of the
    # inputs in `destroy_map`
    destroyed_inputs_idx = set()
    if node.op.destroy_map:
        for i_pos_list in node.op.destroy_map.values():
            destroyed_inputs_idx.update(i_pos_list)
    for inp_idx in destroyed_inputs_idx:
        inp = node.inputs[inp_idx]
        storage_map[inp] = [copy.copy(storage_map[inp][0])]

    # Prepare `storage_map` and `compute_map` for the outputs
    for o in node.outputs:
        storage_map[o] = [None]
        compute_map[o] = [False]

    # Create a thunk that performs the computation
    thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
    thunk.inputs = [storage_map[v] for v in node.inputs]
    thunk.outputs = [storage_map[v] for v in node.outputs]

    required = thunk()
    assert not required  # We provided all inputs

    for output in node.outputs:
        # Check that the output has been computed
        assert compute_map[output][0], (output, storage_map[output][0])

        # Add 'test_value' to output tag, so that downstream `Op`s can use
        # these numerical values as test values
        output.tag.test_value = storage_map[output][0]