def test_pipe_stagequery():
    """Check the first/last-stage flags for every stage of a 3-stage pipeline.

    Only stage 0 may report ``is_first_stage`` and only the final stage
    (``stage_id == stages - 1``) may report ``is_last_stage``; the middle
    stage reports neither.
    """
    num_stages = 3
    for sid in range(num_stages):
        sched = schedule.TrainSchedule(stages=num_stages, micro_batches=4, stage_id=sid)
        # Compare truthiness so this matches a plain ``assert flag`` /
        # ``assert not flag`` check.
        assert bool(sched.is_first_stage) == (sid == 0)
        assert bool(sched.is_last_stage) == (sid == num_stages - 1)
def test_pipe_train_schedule_singlestage():
    """Dump the single-stage training schedule, one step per line.

    Builds a 1-stage, 4-micro-batch ``TrainSchedule`` and checks the
    micro-batch count. It then prints each step index next to that step's
    instruction list. Apart from the count check this is a debugging aid and
    asserts nothing about the schedule's contents.

    NOTE(review): a later definition in this file reuses this exact function
    name, so at import time that later one replaces this one and a test
    collector will only ever see the later version.
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    steps = [*iter(sched)]
    print()  # blank line first so the dump starts on its own line
    for step_no, step_cmds in enumerate(steps):
        print(step_no, step_cmds)
示例#3
0
def test_pipe_schedule_laststage():
    """Last pipeline stage: check the step count and forbidden instructions.

    With 3 stages and 4 micro-batches, the final stage (``stage_id=2``) must
    produce ``2 * (micro_batches + stages - 1)`` steps. No step may contain a
    ``SendActivation`` or ``RecvGrad`` instruction.
    """
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
    # Materialize the schedule once and reuse it. Iterating ``sched`` a second
    # time only works if its ``__iter__`` restarts step generation. If it did
    # not, the per-step loop below would run zero times and pass vacuously.
    steps = list(iter(sched))
    assert len(steps) == 2 * (sched.micro_batches + sched.stages - 1)
    for cmds in steps:
        assert all(instr.__class__ != schedule.SendActivation
                   for instr in cmds)
        assert all(instr.__class__ != schedule.RecvGrad for instr in cmds)
def test_pipe_schedule_firststage():
    """First pipeline stage: no gradient sends or activation receives.

    Uses an 8-micro-batch, 3-stage schedule at ``stage_id=0``. For every step
    it checks two things:

    * no instruction's class is exactly ``SendGrad`` or ``RecvActivation``;
    * every ``BufferOpInstruction`` names a buffer index in
      ``[0, sched.num_pipe_buffers())``.
    """
    sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
    forbidden = (schedule.SendGrad, schedule.RecvActivation)
    for step_cmds in sched:
        # Exact-class comparison: subclasses of the forbidden types are not
        # matched here.
        for banned in forbidden:
            assert not any(op.__class__ == banned for op in step_cmds)
        buffer_ops = [op for op in step_cmds
                      if isinstance(op, schedule.BufferOpInstruction)]
        for op in buffer_ops:
            assert 0 <= op.buffer_id < sched.num_pipe_buffers()
示例#5
0
def test_pipe_train_schedule_singlestage():
    """Validate the instruction layout of a 1-stage, 4-micro-batch schedule.

    With a single stage the steps alternate between two shapes:

    * even-indexed steps hold exactly ``[LoadMicroBatch, ForwardPass]``, and
      both instructions refer to the same buffer;
    * odd-indexed steps hold 1 or 4 instructions, the first of which is a
      ``BackwardPass``.

    The schedule must contain exactly two steps per micro-batch.
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    steps = [*iter(sched)]
    for step_no, step_cmds in enumerate(steps):
        if step_no % 2 == 0:
            # Forward step: load into a buffer, then run forward on that buffer.
            assert len(step_cmds) == 2
            load_cmd, fwd_cmd = step_cmds
            assert type(load_cmd) == schedule.LoadMicroBatch
            assert type(fwd_cmd) == schedule.ForwardPass
            assert load_cmd.buffer_id == fwd_cmd.buffer_id
            continue
        # Backward step: must lead with the backward pass.
        assert len(step_cmds) in (1, 4)
        assert type(step_cmds[0]) == schedule.BackwardPass
    assert len(steps) == sched.num_micro_batches * 2