def test_pydotprint_profile():
    """Draw a function compiled with a ``ProfileStats`` object via pydotprint.

    The graph is rendered twice: once before the function has ever been
    called, and once after a single call, so both states of the attached
    profile are exercised.
    """
    mat = matrix()
    profile_stats = aesara.compile.ProfileStats(atexit_print=False, gpu_checks=False)
    fn = aesara.function([mat], mat + 1, profile=profile_stats)

    # Render before any call has been made ...
    pydotprint(fn, print_output_file=False)

    # ... then call once and render again.
    fn([[1]])
    pydotprint(fn, print_output_file=False)
def test_pydotprint_long_name():
    """Partial test: pydotprint with labels shortened by ``max_label_size``.

    The graph contains variable and apply nodes whose full names differ but
    whose shortened names are the same; such nodes should not be merged in
    the dot graph. Nothing is asserted beyond pydotprint completing, both for
    a compiled function and for a list of symbolic outputs.
    """
    vec = dvector()

    def make_outputs():
        # Build a fresh pair of output expressions on each call.
        return [vec * 2, vec + vec]

    # Fusion is excluded from the default mode.
    no_fusion_mode = aesara.compile.mode.get_default_mode().excluding("fusion")
    fn = aesara.function([vec], make_outputs(), mode=no_fusion_mode)
    fn([1, 2, 3, 4])

    pydotprint(fn, max_label_size=5, print_output_file=False)
    pydotprint(make_outputs(), max_label_size=5, print_output_file=False)
def test_printing_scan():
    """Smoke-test pydotprint on a graph that contains a scan.

    Builds a scan that repeatedly doubles ``state`` for ``n_steps`` steps,
    compiles it, and draws it with ``scan_graphs=True``. Both the symbolic
    output and the compiled function are drawn.
    """

    def f_pow2(x_tm1):
        # One scan step: double the previous value.
        return 2 * x_tm1

    state = scalar("state")
    n_steps = iscalar("nsteps")
    output, updates = aesara.scan(
        f_pow2,
        [],  # sequences: none
        state,  # outputs_info: initial value fed to f_pow2
        [],  # non_sequences: none
        n_steps=n_steps,
        truncate_gradient=-1,
        go_backwards=False,
    )
    f = aesara.function(
        [state, n_steps], output, updates=updates, allow_input_downcast=True
    )

    # print_output_file=False keeps test output quiet, as the other
    # pydotprint tests do.
    pydotprint(output, scan_graphs=True, print_output_file=False)
    pydotprint(f, scan_graphs=True, print_output_file=False)
def test_pydotprint_cond_highlight():
    """Partial debugging test: ``cond_highlight`` on a graph without IfElse.

    The graph ``vec * 2`` has no IfElse node, so pydotprint should emit a
    message through ``aesara.aesara_logger`` saying there is nothing to
    highlight. The logger's default handler is temporarily replaced by one
    that writes into a StringIO, and the captured text is compared exactly.
    """
    vec = dvector()
    fn = aesara.function([vec], vec * 2)
    fn([1, 2, 3, 4])

    captured = StringIO()
    capture_handler = logging.StreamHandler(captured)
    capture_handler.setLevel(logging.DEBUG)

    logger = aesara.aesara_logger
    default_handler = aesara.logging_default_handler
    logger.removeHandler(default_handler)
    logger.addHandler(capture_handler)
    try:
        pydotprint(fn, cond_highlight=True, print_output_file=False)
    finally:
        # Always put the original handler back, even if pydotprint raises.
        logger.addHandler(default_handler)
        logger.removeHandler(capture_handler)

    expected = (
        "pydotprint: cond_highlight is set but there"
        " is no IfElse node in the graph\n"
    )
    assert captured.getvalue() == expected
def test_pydotprint_return_image():
    """With ``return_image=True``, pydotprint hands back the image as str or bytes."""
    vec = dvector()
    doubled = vec * 2
    image = pydotprint(doubled, return_image=True)
    assert isinstance(image, (str, bytes))