示例#1
0
def _models_loader(r: registry.Registry):
    """Register GAN model classes into ``r``.

    Always registers the contents of ``catalyst_gan.models``. If the optional
    ``torch_mimicry`` package is importable, its ``sngan``, ``ssgan`` and
    ``infomax_gan`` network modules are registered as well, each under both
    the shared ``tm.`` prefix and a model-specific prefix (e.g. ``tm.sngan.``).

    If ``torch_mimicry`` cannot be imported, an ``ImportWarning`` is emitted
    and only the ``catalyst_gan`` models are registered.

    Args:
        r: registry to populate.
    """
    import warnings

    from catalyst_gan import models
    r.add_from_module(models)

    # Only the import itself is optional. Registration is kept outside the
    # ``try`` so that an ImportError raised while registering is not mistaken
    # for "torch_mimicry is not installed" and silently swallowed.
    try:
        from torch_mimicry.nets import sngan, ssgan, infomax_gan
    except ImportError as err:
        # ImportWarning is ignored by Python's default warning filters, so a
        # missing optional dependency stays quiet unless warnings are enabled
        # (e.g. ``python -W default``).
        warnings.warn(
            f"torch_mimicry models are not registered: {err}",
            ImportWarning,
        )
        return

    r.add_from_module(sngan, prefix=['tm.', 'tm.sngan.'])
    r.add_from_module(ssgan, prefix=['tm.', 'tm.ssgan.'])
    r.add_from_module(infomax_gan, prefix=['tm.', 'tm.infomax_gan.'])
示例#2
0
def _datasets_loader(r: registry.Registry):
    """Register dataset classes into ``r``.

    Always registers the contents of ``catalyst_gan.datasets``. If
    ``torchvision`` is importable, it additionally registers:

    * the ``torchvision.datasets`` module contents, under the ``tv.`` and
      ``torchvision.`` prefixes;
    * for every class returned by ``_get_module_classes(datasets)``, a
      ``functools.partial`` of ``_TorchvisionDatasetWrapper`` with
      ``base_dataset`` bound to that class, under the ``tv.kv.`` and
      ``torchvision.keyvalue.`` prefixes. The wrapper presumably makes
      ``dataset[i]`` return a key-value item (TODO confirm against
      ``_TorchvisionDatasetWrapper``).

    If ``torchvision`` cannot be imported, an ``ImportWarning`` is emitted
    and only the ``catalyst_gan`` datasets are registered.

    Args:
        r: registry to populate.
    """
    import warnings

    from catalyst_gan import datasets as ext_datasets
    r.add_from_module(ext_datasets)

    # Only the import itself is optional. Registration is kept outside the
    # ``try`` so ImportErrors raised while collecting or registering classes
    # are not silently swallowed as "torchvision is not installed".
    try:
        from torchvision import datasets
    except ImportError as err:
        # ImportWarning is ignored by Python's default warning filters, so a
        # missing optional dependency stays quiet unless warnings are enabled.
        warnings.warn(
            f"torchvision datasets are not registered: {err}",
            ImportWarning,
        )
        return

    r.add_from_module(datasets, prefix=["tv.", "torchvision."])

    # wrapped datasets
    prefixes = ["tv.kv.", "torchvision.keyvalue."]
    for name, cls in _get_module_classes(datasets).items():
        # ``partial`` captures ``cls`` at creation time, so every entry wraps
        # its own dataset class (no late-binding closure problem).
        cls_w = functools.partial(_TorchvisionDatasetWrapper,
                                  base_dataset=cls)
        r.add(**{f"{p}{name}": cls_w for p in prefixes})
示例#3
0
def _modules_loader(r: registry.Registry):
    """Register the contents of ``catalyst_gan.nn`` into ``r``.

    Args:
        r: registry that ``Registry.add_from_module`` populates.
    """
    from catalyst_gan import nn as gan_nn_module
    r.add_from_module(gan_nn_module)
示例#4
0
def _metrics_loader(r: registry.Registry):
    """Register the contents of ``catalyst_gan.metrics`` into ``r``.

    Args:
        r: registry that ``Registry.add_from_module`` populates.
    """
    from catalyst_gan import metrics as metrics_module
    r.add_from_module(metrics_module)
示例#5
0
def _batch_transforms_loader(r: registry.Registry):
    """Register the contents of ``catalyst_gan.batch_transforms`` into ``r``.

    Args:
        r: registry that ``Registry.add_from_module`` populates.
    """
    from catalyst_gan import batch_transforms as transforms_module
    r.add_from_module(transforms_module)
示例#6
0
def _callbacks_loader(r: registry.Registry):
    """Register the contents of ``catalyst_gan.callbacks`` into ``r``.

    Args:
        r: registry that ``Registry.add_from_module`` populates.
    """
    from catalyst_gan import callbacks as callbacks_module
    r.add_from_module(callbacks_module)
示例#7
0
def _criterions_loader(r: registry.Registry):
    """Register the contents of ``catalyst_gan.nn.criterion`` into ``r``.

    Args:
        r: registry that ``Registry.add_from_module`` populates.
    """
    from catalyst_gan.nn import criterion as criterion_module
    r.add_from_module(criterion_module)