def __init__(self, *args, **kwargs):
    """Build the exception message from ``args`` plus an optional variable trace.

    Parameters
    ----------
    *args
        Message fragments. Each fragment is converted with ``str`` and the
        fragments are joined with newlines, so each one prints on its own line.
    variable : optional, keyword-only
        If given, ``get_variable_trace_string(variable)`` is called. When it
        returns a non-empty result, that result is appended as the last
        fragment of the message.

    Raises
    ------
    TypeError
        If any keyword argument other than ``variable`` is passed.
    """
    if kwargs:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        unexpected = sorted(set(kwargs) - {"variable"})
        if unexpected:
            raise TypeError(
                f"{type(self).__name__}() got unexpected keyword argument(s): "
                f"{', '.join(unexpected)}"
            )
        error_msg = get_variable_trace_string(kwargs["variable"])
        if error_msg:
            args = args + (error_msg,)
    # Join with real newlines so the message prints across lines; `str` guards
    # against non-string fragments, which `str.join` would reject.
    s = "\n".join(map(str, args))
    super().__init__(s)
def compute_test_value(node: Apply):
    """Computes the test value of a node.

    Each input's test value is fetched with ``get_test_value()``. If every
    input has one, the node's `Op` is executed on those values through a thunk
    built by ``node.op.make_thunk`` and the results are attached to the outputs.

    Parameters
    ----------
    node : Apply
        The `Apply` node for which the test value is computed.

    Returns
    -------
    None
        The `tag.test_value`s are updated in each `Variable` in `node.outputs`.
        If an input has no test value and ``config.compute_test_value`` is
        ``"warn"``, ``"ignore"`` or ``"pdb"``, the function returns early and
        the outputs' test values are not set.

    Raises
    ------
    ValueError
        If an input has no test value and ``config.compute_test_value`` is
        ``"raise"``, or if ``config.compute_test_value`` holds a value not
        handled here.
    """
    # Gather the test values for each input of the node
    storage_map = {}
    compute_map = {}
    for i, ins in enumerate(node.inputs):
        try:
            storage_map[ins] = [ins.get_test_value()]
            compute_map[ins] = [True]
        except TestValueError:
            # no test-value was specified, act accordingly
            if config.compute_test_value == "warn":
                warnings.warn(
                    f"Warning, Cannot compute test value: input {i} ({ins}) of Op {node} missing default value",
                    stacklevel=2,
                )
                return
            elif config.compute_test_value == "raise":
                detailed_err_msg = get_variable_trace_string(ins)
                raise ValueError(
                    f"Cannot compute test value: input {i} ({ins}) of Op {node} missing default value. {detailed_err_msg}"
                )
            elif config.compute_test_value == "ignore":
                return
            elif config.compute_test_value == "pdb":
                import pdb

                pdb.post_mortem(sys.exc_info()[2])
                # This input has no `storage_map` entry, so the thunk below
                # cannot be set up; stop once the debugging session ends
                # (as "warn"/"ignore" do) instead of failing with a KeyError.
                return
            else:
                raise ValueError(
                    f"{config.compute_test_value} is invalid for option config.compute_test_value"
                )

    # All inputs have test-values; perform the `Op`'s computation

    # The original values should not be destroyed, so we copy the values of the
    # inputs in `destroy_map`
    destroyed_inputs_idx = set()
    if node.op.destroy_map:
        for i_pos_list in node.op.destroy_map.values():
            destroyed_inputs_idx.update(i_pos_list)
    for inp_idx in destroyed_inputs_idx:
        inp = node.inputs[inp_idx]
        storage_map[inp] = [copy.copy(storage_map[inp][0])]

    # Prepare `storage_map` and `compute_map` for the outputs
    for o in node.outputs:
        storage_map[o] = [None]
        compute_map[o] = [False]

    # Create a thunk that performs the computation
    thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
    thunk.inputs = [storage_map[v] for v in node.inputs]
    thunk.outputs = [storage_map[v] for v in node.outputs]

    required = thunk()
    assert not required  # We provided all inputs

    for output in node.outputs:
        # Check that the output has been computed
        assert compute_map[output][0], (output, storage_map[output][0])

        # Add 'test_value' to output tag, so that downstream `Op`s can use
        # these numerical values as test values
        output.tag.test_value = storage_map[output][0]