def test_classes_file_path():
    """Check that ``classes`` accepts a path to a newline-separated text file.

    Three cases are exercised against ``CityscapesDataset``:

    * a file listing every Cityscapes category yields the full list;
    * a file listing a subset of categories yields exactly that subset;
    * a file containing a name that is not a Cityscapes category is
      expected to make construction raise ``ValueError``.

    The class file lives in a temporary directory. That directory is removed
    even when an assertion fails, so a failing run leaves no stray files.
    """
    train_pipeline = [dict(type='LoadImageFromFile')]

    with tempfile.TemporaryDirectory() as tmp_dir:
        classes_path = osp.join(tmp_dir, 'classes.txt')
        kwargs = dict(
            pipeline=train_pipeline, img_dir='./', classes=classes_path)

        def write_classes(names):
            # Overwrite the class file with one category name per line.
            with open(classes_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(names))

        # classes.txt with full categories
        categories = get_classes('cityscapes')
        write_classes(categories)
        assert list(CityscapesDataset(**kwargs).CLASSES) == categories

        # classes.txt with sub categories
        categories = ['road', 'sidewalk', 'building']
        write_classes(categories)
        assert list(CityscapesDataset(**kwargs).CLASSES) == categories

        # classes.txt with unknown categories
        write_classes(['road', 'sidewalk', 'unknown'])
        with pytest.raises(ValueError):
            CityscapesDataset(**kwargs)

    # Leaving the ``with`` block deletes the directory and the file inside it.
    assert not osp.exists(classes_path)
def test_classes():
    """Check that each dataset's ``CLASSES`` matches ``get_classes``.

    Also checks that the dataset-name aliases ('voc'/'pascal_voc' and
    'ade'/'ade20k') resolve to the same list, and that an unknown dataset
    name raises ``ValueError``.
    """
    # A second, shorter ``test_classes`` used to precede this one. It was
    # shadowed by this definition, so pytest never ran it, and its checks are
    # a strict subset of the ones below.
    assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes')
    assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes(
        'pascal_voc')
    assert list(
        ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
    assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff')
    assert list(LoveDADataset.CLASSES) == get_classes('loveda')
    assert list(PotsdamDataset.CLASSES) == get_classes('potsdam')
    assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen')
    assert list(iSAIDDataset.CLASSES) == get_classes('isaid')

    with pytest.raises(ValueError):
        get_classes('unsupported')