def test_pipe_stagequery():
    """Check first/last-stage flags for every stage of a 3-stage pipeline.

    Only stage 0 should report ``is_first_stage`` and only stage 2 (the final
    one) should report ``is_last_stage``; the middle stage is neither.
    """
    # (stage_id, expected is_first_stage, expected is_last_stage)
    expectations = [
        (0, True, False),
        (1, False, False),
        (2, False, True),
    ]
    for stage, want_first, want_last in expectations:
        sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=stage)
        # bool() keeps the original truthiness semantics of the flags.
        assert bool(sched.is_first_stage) == want_first
        assert bool(sched.is_last_stage) == want_last
def test_pipe_train_schedule_singlestage_trace():
    """Trace the single-stage training schedule step by step.

    This test has its own name. An earlier version reused the name
    ``test_pipe_train_schedule_singlestage``, which is defined again later in
    this module. That later definition rebinds the module-level name, so this
    one was silently shadowed and never collected by pytest.

    Each step is printed (visible with ``pytest -s``) to help diagnose
    failures in the stricter single-stage test. It also checks that the
    schedule produces at least one step and that no step is empty.
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4

    full = list(sched)
    assert full, "single-stage schedule produced no steps"

    for idx, cmds in enumerate(full):
        print(idx, cmds)
        # Every single-stage step carries at least one instruction
        # (load+forward, or a backward pass, optionally followed by reductions).
        assert len(cmds) > 0, f"step {idx} has no instructions"
def test_pipe_schedule_laststage():
    """Check the step count and forbidden instructions on the last stage.

    The last stage of a 3-stage pipeline runs
    ``2 * (micro_batches + stages - 1)`` steps. Because it has no downstream
    neighbour, it must never emit ``SendActivation`` or ``RecvGrad``.
    """
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)

    n_steps = len(list(sched))
    assert n_steps == 2 * (sched.micro_batches + sched.stages - 1)

    # Exact-class match, mirroring the original ``__class__ !=`` comparisons.
    banned = (schedule.SendActivation, schedule.RecvGrad)
    for step_cmds in sched:
        assert not any(cmd.__class__ in banned for cmd in step_cmds)
def test_pipe_schedule_firststage():
    """Check the first stage's instructions and buffer indices.

    The first stage has no upstream neighbour, so it must never emit
    ``SendGrad`` or ``RecvActivation``. Every buffer-backed instruction must
    also reference a buffer index in ``[0, num_pipe_buffers())``.
    """
    sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)

    # Exact-class match, mirroring the original ``__class__ !=`` comparisons.
    banned = (schedule.SendGrad, schedule.RecvActivation)
    for step_cmds in sched:
        assert not any(cmd.__class__ in banned for cmd in step_cmds)

        buffered = (
            cmd for cmd in step_cmds
            if isinstance(cmd, schedule.BufferOpInstruction)
        )
        for cmd in buffered:
            assert 0 <= cmd.buffer_id < sched.num_pipe_buffers()
def test_pipe_train_schedule_singlestage():
    """Check that a single-stage schedule alternates forward and backward steps.

    Even-indexed steps load a micro-batch and run the forward pass on the
    same buffer. Odd-indexed steps start with a ``BackwardPass``; they hold
    either that one instruction or four instructions. The schedule has
    exactly two steps per micro-batch.
    """
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4

    steps = list(sched)
    forward_steps = steps[0::2]
    backward_steps = steps[1::2]

    for cmds in forward_steps:
        assert len(cmds) == 2
        assert type(cmds[0]) == schedule.LoadMicroBatch
        assert type(cmds[1]) == schedule.ForwardPass
        # The forward pass must consume the buffer that was just loaded.
        assert cmds[0].buffer_id == cmds[1].buffer_id

    for cmds in backward_steps:
        assert len(cmds) in (1, 4)
        assert type(cmds[0]) == schedule.BackwardPass

    assert len(steps) == sched.num_micro_batches * 2