示例#1
0
def test_record_bad():
    """Check that ``Record`` flags a replay that deviates from the recording.

    Ten lines ("0\\n" .. "9\\n") are recorded. The first half is then replayed
    faithfully, after which a line that differs from the next recorded one
    must raise ``MismatchError``.
    """
    num_lines = 10

    # Record a sequence of events.
    output = StringIO()
    recorder = Record(file_object=output, replay=False)
    for i in range(num_lines):
        recorder.handle_line(str(i) + "\n")

    # Replaying a faithful prefix of the recording must not raise.
    playback_checker = Record(
        file_object=StringIO(output.getvalue()), replay=True
    )
    for i in range(num_lines // 2):
        playback_checker.handle_line(str(i) + "\n")

    # The sixth recorded line is "5\n"; replaying "0\n" in its place deviates
    # from the recorded sequence and must be detected.
    try:
        playback_checker.handle_line("0\n")
    except MismatchError:
        return
    # Adjacent literals are joined with a single space between the phrases.
    raise AssertionError(
        "Failed to detect mismatch between recorded sequence "
        "and repetition of it."
    )
示例#2
0
def test_record_good():
    """Record a sequence of lines, then replay it verbatim.

    Verifies that ``Record``:
      1) writes out exactly the lines it was handed, and
      2) raises nothing when the identical lines are replayed in order.
    """
    n_events = 10
    events = [str(k) + "\n" for k in range(n_events)]

    # Recording pass: every line handed to the recorder lands in the buffer.
    buffer = StringIO()
    writer = Record(file_object=buffer, replay=False)
    for event in events:
        writer.handle_line(event)

    recorded = buffer.getvalue()
    assert recorded == "".join(events)

    # Replay pass: feeding back the very same sequence must not raise.
    checker = Record(file_object=StringIO(recorded), replay=True)
    for event in events:
        checker.handle_line(event)
示例#3
0
    def run(replay, log=None):
        """Build and execute a small Aesara graph under ``RecordMode``.

        Presumably called twice by the enclosing test: first with
        ``replay=False`` to produce a log, then with ``replay=True`` and that
        log so that ``Record`` checks the second run reproduces the first.

        Parameters
        ----------
        replay : bool
            If false, events are recorded into a fresh ``StringIO`` and the
            ``log`` argument is ignored. If true, events are checked against
            ``log``.
        log : str, optional
            Text previously returned by ``run(replay=False)``; only used when
            ``replay`` is true.

        Returns
        -------
        str or None
            The recorded log text when ``replay`` is false; ``None`` otherwise.
        """

        # Record mode starts from an empty buffer; replay mode wraps the
        # previously recorded text so Record can read it back.
        if not replay:
            log = StringIO()
        else:
            log = StringIO(log)
        record = Record(replay=replay, file_object=log)

        # NOTE(review): disturb_mem() is interleaved with graph construction,
        # presumably to perturb memory/allocation state so that anything that
        # depends on it shows up as a mismatch on replay. Confirm against
        # disturb_mem's definition (not visible here).
        disturb_mem()

        mode = RecordMode(record=record)

        b = sharedX(np.zeros((2, )), name="b")
        # NOTE(review): ``channels`` is never populated in this function, so
        # the ``for key in channels`` loop further down never executes.
        channels = OrderedDict()

        disturb_mem()

        v_max = b.max(axis=0)
        v_min = b.min(axis=0)
        v_range = v_max - v_min

        # One shared scalar per reduction, each updated to that reduction.
        updates = []
        for i, val in enumerate([
                v_max.max(),
                v_max.min(),
                v_range.max(),
        ]):
            disturb_mem()
            s = sharedX(0.0, name="s_" + str(i))
            updates.append((s, val))

        # Clear the names of variables in the update expressions' graphs,
        # keeping "b" and any two-character name that starts with "s".
        # Presumably this keeps the recorded descriptors independent of names.
        # NOTE(review): the shared scalars above are named "s_0".."s_2" (three
        # characters), so the two-character exception never matches them;
        # they are not inputs of the update expressions either.
        for var in aesara.graph.basic.ancestors(update
                                                for _, update in updates):
            if var.name is not None and var.name != "b":
                if var.name[0] != "s" or len(var.name) != 2:
                    var.name = None

        # Dead while ``channels`` is empty; it would also reuse the last ``s``.
        for key in channels:
            updates.append((s, channels[key]))
        f = aesara.function([],
                            mode=mode,
                            updates=updates,
                            on_unused_input="ignore",
                            name="f")
        # Record one descriptor line per compiled output, presumably so that
        # replay also compares the structure of the compiled graph.
        for output in f.maker.fgraph.outputs:
            mode.record.handle_line(var_descriptor(output) + "\n")
        disturb_mem()
        f()

        # Flush the record's file object (assumed to be the ``log`` buffer
        # passed to Record above) so all recorded text is visible.
        mode.record.f.flush()

        if not replay:
            return log.getvalue()
示例#4
0
def test_record_mode_bad():
    """Like ``test_record_bad``, but some events come from ``RecordMode``.

    Each recorded step is a manual ``handle_line`` call followed by a call to
    an Aesara function compiled with ``RecordMode``. On replay, the first half
    of the steps is repeated faithfully; then the function is called where a
    manual line was recorded, and that must raise ``MismatchError``.
    """
    num_lines = 10

    # Record a sequence of events. The symbolic input is named ``x`` so it is
    # not shadowed by the integer loop index ``i`` below.
    output = StringIO()
    recorder = Record(file_object=output, replay=False)
    record_mode = RecordMode(recorder)
    x = iscalar()
    f = function([x], x, mode=record_mode, name="f")
    for i in range(num_lines):
        recorder.handle_line(str(i) + "\n")
        f(i)

    # Replaying the first half of the recorded steps must not raise.
    playback_checker = Record(
        file_object=StringIO(output.getvalue()), replay=True
    )
    playback_mode = RecordMode(playback_checker)
    x = iscalar()
    f = function([x], x, mode=playback_mode, name="f")
    for i in range(num_lines // 2):
        playback_checker.handle_line(str(i) + "\n")
        f(i)

    # The next recorded event is the manual line "5\n"; calling ``f`` instead
    # deviates from the recording and must cause a MismatchError.
    try:
        f(0)
    except MismatchError:
        return
    raise AssertionError("Failed to detect a mismatch.")
示例#5
0
def test_record_mode_good():
    """Like ``test_record_good``, but some events come from ``RecordMode``.

    Manual ``handle_line`` calls are interleaved with calls to an Aesara
    function compiled in ``RecordMode``; the identical interleaving is then
    replayed and must not raise. The exact record text is not checked here.
    """
    # --- recording pass ---
    sink = StringIO()
    writer = Record(file_object=sink, replay=False)
    writing_mode = RecordMode(writer)
    inp = iscalar()
    fn = function([inp], inp, mode=writing_mode, name="f")

    steps = 10
    for step in range(steps):
        writer.handle_line(str(step) + "\n")
        fn(step)

    # --- replay pass: the same interleaving must not raise ---
    checker = Record(file_object=StringIO(sink.getvalue()), replay=True)
    checking_mode = RecordMode(checker)
    inp = iscalar()
    fn = function([inp], inp, mode=checking_mode, name="f")

    for step in range(steps):
        checker.handle_line(str(step) + "\n")
        fn(step)