def test_record_bad(): # Tests that when we record a sequence of events, then # do something different on playback, the Record class catches it. # Record a sequence of events output = StringIO() recorder = Record(file_object=output, replay=False) num_lines = 10 for i in range(num_lines): recorder.handle_line(str(i) + "\n") # Make sure that the playback functionality doesn't raise any errors # when we repeat some of them output_value = output.getvalue() output = StringIO(output_value) playback_checker = Record(file_object=output, replay=True) for i in range(num_lines // 2): playback_checker.handle_line(str(i) + "\n") # Make sure it raises an error when we deviate from the recorded sequence try: playback_checker.handle_line("0\n") except MismatchError: return raise AssertionError("Failed to detect mismatch between recorded sequence " " and repetition of it.")
def test_record_good(): # Tests that when we record a sequence of events, then # repeat it exactly, the Record class: # 1) Records it correctly # 2) Does not raise any errors # Record a sequence of events output = StringIO() recorder = Record(file_object=output, replay=False) num_lines = 10 for i in range(num_lines): recorder.handle_line(str(i) + "\n") # Make sure they were recorded correctly output_value = output.getvalue() assert output_value == "".join(str(i) + "\n" for i in range(num_lines)) # Make sure that the playback functionality doesn't raise any errors # when we repeat them output = StringIO(output_value) playback_checker = Record(file_object=output, replay=True) for i in range(num_lines): playback_checker.handle_line(str(i) + "\n")
def run(replay, log=None): if not replay: log = StringIO() else: log = StringIO(log) record = Record(replay=replay, file_object=log) disturb_mem() mode = RecordMode(record=record) b = sharedX(np.zeros((2, )), name="b") channels = OrderedDict() disturb_mem() v_max = b.max(axis=0) v_min = b.min(axis=0) v_range = v_max - v_min updates = [] for i, val in enumerate([ v_max.max(), v_max.min(), v_range.max(), ]): disturb_mem() s = sharedX(0.0, name="s_" + str(i)) updates.append((s, val)) for var in aesara.graph.basic.ancestors(update for _, update in updates): if var.name is not None and var.name != "b": if var.name[0] != "s" or len(var.name) != 2: var.name = None for key in channels: updates.append((s, channels[key])) f = aesara.function([], mode=mode, updates=updates, on_unused_input="ignore", name="f") for output in f.maker.fgraph.outputs: mode.record.handle_line(var_descriptor(output) + "\n") disturb_mem() f() mode.record.f.flush() if not replay: return log.getvalue()
def test_record_mode_bad(): # Like test_record_bad, but some events are recorded by the # aesara RecordMode, as is the event that triggers the mismatch # error. # Record a sequence of events output = StringIO() recorder = Record(file_object=output, replay=False) record_mode = RecordMode(recorder) i = iscalar() f = function([i], i, mode=record_mode, name="f") num_lines = 10 for i in range(num_lines): recorder.handle_line(str(i) + "\n") f(i) # Make sure that the playback functionality doesn't raise any errors # when we repeat them output_value = output.getvalue() output = StringIO(output_value) playback_checker = Record(file_object=output, replay=True) playback_mode = RecordMode(playback_checker) i = iscalar() f = function([i], i, mode=playback_mode, name="f") for i in range(num_lines // 2): playback_checker.handle_line(str(i) + "\n") f(i) # Make sure a wrong event causes a MismatchError try: f(0) except MismatchError: return raise AssertionError("Failed to detect a mismatch.")
def test_record_mode_good(): # Like test_record_good, but some events are recorded by the # aesara RecordMode. We don't attempt to check the # exact string value of the record in this case. # Record a sequence of events output = StringIO() recorder = Record(file_object=output, replay=False) record_mode = RecordMode(recorder) i = iscalar() f = function([i], i, mode=record_mode, name="f") num_lines = 10 for i in range(num_lines): recorder.handle_line(str(i) + "\n") f(i) # Make sure that the playback functionality doesn't raise any errors # when we repeat them output_value = output.getvalue() output = StringIO(output_value) playback_checker = Record(file_object=output, replay=True) playback_mode = RecordMode(playback_checker) i = iscalar() f = function([i], i, mode=playback_mode, name="f") for i in range(num_lines): playback_checker.handle_line(str(i) + "\n") f(i)