def test_clear_stages_with_ref(stage_factory: StageFactoryBase):
    """Check that ``clear_stages`` keeps referenced stages unless forced.

    Two stages are registered and only the first is referenced. A plain
    ``clear_stages()`` must drop the unreferenced stage but keep the
    referenced one. ``clear_stages(force=True)`` must drop it as well, while
    the outstanding reference still points at it. Returning the reference
    finally clears the stage's ``has_ref`` flag.
    """
    kept = make_stage(stage_factory)
    dropped = make_stage(stage_factory)
    kept.name = 'stage'
    dropped.name = 'stage2'
    for stage in (kept, dropped):
        stage_factory.add_stage(stage)

    def assert_present(name, stage):
        # the name maps to this exact stage and it is in the stage list
        assert stage_factory.stage_names[name] is stage
        assert stage in stage_factory.stages

    def assert_absent(name, stage):
        assert name not in stage_factory.stage_names
        assert stage not in stage_factory.stages

    def assert_held(ref, stage):
        assert ref.stage is stage
        assert stage.has_ref

    assert_present('stage', kept)
    assert_present('stage2', dropped)

    kept_ref = stage_factory.get_stage_ref(name='stage')

    # non-forced clear: only the unreferenced stage goes away
    stage_factory.clear_stages()
    assert_held(kept_ref, kept)
    assert_present('stage', kept)
    assert_absent('stage2', dropped)

    # forced clear: the referenced stage is removed too, ref still valid
    stage_factory.clear_stages(force=True)
    assert_held(kept_ref, kept)
    assert_absent('stage', kept)

    stage_factory.return_stage_ref(kept_ref)
    assert not kept.has_ref
def test_stage_ref(stage_factory: StageFactoryBase):
    """Check reference bookkeeping for a named stage and an unregistered one.

    The first stage is added to the factory and referenced by name. The
    second is never added and is referenced directly by object. A stage must
    report ``has_ref`` and appear in ``stage_factory._stage_ref`` exactly
    while a reference to it is outstanding.
    """
    named_stage = make_stage(stage_factory)
    named_stage.name = 'me stage'
    loose_stage = make_stage(stage_factory)
    stage_factory.add_stage(named_stage)

    def assert_refs(named_held, loose_held):
        expected = ((named_stage, named_held), (loose_stage, loose_held))
        # the stages' own flags first, then the factory's tracking
        for stage, held in expected:
            assert bool(stage.has_ref) == held
        for stage, held in expected:
            assert (stage in stage_factory._stage_ref) == held

    named_ref = stage_factory.get_stage_ref(name='me stage')
    loose_ref = stage_factory.get_stage_ref(stage=loose_stage)
    assert named_ref.stage is named_stage
    assert loose_ref.stage is loose_stage
    assert_refs(True, True)

    # returning one reference must not disturb the other
    stage_factory.return_stage_ref(named_ref)
    assert loose_ref.stage is loose_stage
    assert_refs(False, True)

    stage_factory.return_stage_ref(loose_ref)
    assert_refs(False, False)
def test_return_not_added_stage_ref(stage_factory: StageFactoryBase):
    """Returning a reference not obtained from the factory raises ValueError.

    The reference is built directly with ``CeedStageRef`` instead of through
    ``get_stage_ref``, so the factory has no record of it.
    """
    from ceed.stage import CeedStageRef

    stage = make_stage(stage_factory)
    ref_kwargs = dict(
        stage_factory=stage_factory,
        function_factory=stage_factory.function_factory,
        shape_factory=stage_factory.shape_factory,
        stage=stage,
    )
    foreign_ref = CeedStageRef(**ref_kwargs)

    with pytest.raises(ValueError):
        stage_factory.return_stage_ref(foreign_ref)
def test_remove_stage_with_ref(stage_factory: StageFactoryBase):
    """Check that ``remove_stage`` refuses referenced stages unless forced.

    Three stages are registered, and the first and third are referenced.
    Removing a referenced stage must return a falsy value and leave the
    factory untouched. Removing an unreferenced stage must succeed.
    ``force=True`` must remove a referenced stage while its reference keeps
    pointing at it. After its reference is returned, a stage can be removed
    normally.
    """
    first = make_stage(stage_factory)
    second = make_stage(stage_factory)
    third = make_stage(stage_factory)
    first.name = 'stage'
    second.name = 'stage2'
    third.name = 'stage3'
    for stage in (first, second, third):
        stage_factory.add_stage(stage)

    def assert_registered(*pairs):
        # each (name, stage): name maps to the stage and it is in the list
        for name, stage in pairs:
            assert stage_factory.stage_names[name] is stage
            assert stage in stage_factory.stages

    def assert_unregistered(name, stage):
        assert name not in stage_factory.stage_names
        assert stage not in stage_factory.stages

    def assert_held(ref, stage):
        # the ref still points at its stage, which reports being referenced
        assert ref.stage is stage
        assert stage.has_ref

    assert_registered(('stage', first), ('stage2', second), ('stage3', third))

    first_ref = stage_factory.get_stage_ref(name='stage')
    third_ref = stage_factory.get_stage_ref(name='stage3')

    # a referenced stage is not removed without force
    assert not stage_factory.remove_stage(first)
    assert_held(first_ref, first)
    assert_registered(('stage', first), ('stage2', second), ('stage3', third))

    # an unreferenced stage is removed normally
    assert stage_factory.remove_stage(second)
    assert_registered(('stage', first))
    assert_unregistered('stage2', second)
    assert_registered(('stage3', third))

    assert not stage_factory.remove_stage(third)
    assert_held(third_ref, third)
    assert_registered(('stage', first), ('stage3', third))

    # force removes it from the factory, but the ref keeps pointing to it
    assert stage_factory.remove_stage(third, force=True)
    assert_held(third_ref, third)
    assert_registered(('stage', first))
    assert_unregistered('stage3', third)

    assert not stage_factory.remove_stage(first)
    assert_held(first_ref, first)
    assert_registered(('stage', first))

    # once its reference is returned, the stage can be removed
    stage_factory.return_stage_ref(first_ref)
    assert not first.has_ref
    assert stage_factory.remove_stage(first)
    assert_unregistered('stage', first)

    stage_factory.return_stage_ref(third_ref)
    assert not third.has_ref