def _datasets_loader(r: registry.Registry):
    # TODO add original torchvision and wrapped datasets (with kv dataset[i])
    from catalyst_gan import datasets as ext_datasets
    r.add_from_module(ext_datasets)

    try:
        from torchvision import datasets
        r.add_from_module(datasets, prefix=["tv.", "torchvision."])

        # wrapped datasets: register each torchvision dataset class behind
        # _TorchvisionDatasetWrapper so that dataset[i] returns a key-value sample
        datasets_to_add = _get_module_classes(datasets)
        prefixes = ["tv.kv.", "torchvision.keyvalue."]
        for name, cls in datasets_to_add.items():
            cls_w = functools.partial(_TorchvisionDatasetWrapper, base_dataset=cls)
            r.add(**{f"{p}{name}": cls_w for p in prefixes})
    except ImportError:
        pass
def _models_loader(r: registry.Registry):
    from catalyst_gan import models
    r.add_from_module(models)

    try:
        from torch_mimicry.nets import sngan, ssgan, infomax_gan
        r.add_from_module(sngan, prefix=["tm.", "tm.sngan."])
        r.add_from_module(ssgan, prefix=["tm.", "tm.ssgan."])
        r.add_from_module(infomax_gan, prefix=["tm.", "tm.infomax_gan."])
    except ImportError:
        pass  # TODO: warning?
def _modules_loader(r: registry.Registry):
    from catalyst_gan import nn
    r.add_from_module(nn)


def _metrics_loader(r: registry.Registry):
    from catalyst_gan import metrics
    r.add_from_module(metrics)


def _batch_transforms_loader(r: registry.Registry):
    from catalyst_gan import batch_transforms
    r.add_from_module(batch_transforms)


def _callbacks_loader(r: registry.Registry):
    from catalyst_gan import callbacks
    r.add_from_module(callbacks)


def _criterions_loader(r: registry.Registry):
    from catalyst_gan.nn import criterion
    r.add_from_module(criterion)
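

# Illustrative wiring (a sketch, not taken from this module): loader callbacks
# like the ones above are typically attached to catalyst registries through
# late_add, so the optional imports (torchvision, torch_mimicry) only run when
# a registry is first queried. The registry names and the `get` lookup below
# are assumptions for illustration.
#
# MODEL = registry.Registry("model")
# MODEL.late_add(_models_loader)
#
# DATASET = registry.Registry("dataset")
# DATASET.late_add(_datasets_loader)
#
# # Config entries then resolve through the prefixed names registered above,
# # e.g. a key-value wrapped torchvision dataset:
# # dataset_cls = DATASET.get("tv.kv.MNIST")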