示例#1
0
from homura import Registry
SEMI_DATASET_REGISTRY = Registry('semi_dataset')

# The dataset definitions are imported only after the registry above exists;
# the submodule presumably registers into SEMI_DATASET_REGISTRY at import time
# (TODO confirm). Nesting the import inside a block presumably keeps import
# sorters from hoisting it above the registry definition.
if True:
    from .datasets import *
示例#2
0
from homura import Registry, get_environ
from .classification import ExtraSVHN, ImageNet, OriginalSVHN
from .detection import VOCDetection, det_collate_fn
from .segmentation import ExtendedVOCSegmentation, seg_collate_fn
from .visionset import VisionSet

# Registry of vision datasets. `type=VisionSet` presumably restricts entries to
# VisionSet instances (TODO confirm against homura.Registry); the entries added
# below through register_from_dict are all VisionSet objects.
DATASET_REGISTRY = Registry('vision_datasets', type=VisionSet)

from torchvision import datasets, transforms
from .. import transforms as homura_transforms

DATASET_REGISTRY.register_from_dict({
    'cifar10':
    VisionSet(datasets.CIFAR10, "~/.torch/data/cifar10", 10, [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ], [
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip()
    ]),
    'cifar100':
    VisionSet(datasets.CIFAR100, "~/.torch/data/cifar100", 100, [
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ], [
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip()
    ]),
    'svhn':
示例#3
0
from homura import Registry

MODEL_REGISTRY = Registry('vision_model')

from torchvision import models

# Expose a fixed set of stock torchvision architectures through the registry.
# Each builder is looked up and registered in turn, in the order listed.
for _arch_name in ("resnet18",
                   "resnet50",
                   "wide_resnet50_2",
                   "densenet121",
                   "vgg19_bn"):
    MODEL_REGISTRY.register(getattr(models, _arch_name))
# Keep the loop variable out of the module namespace.
del _arch_name

from .densenet import densenet40, densenet100
from .cifar_resnet import wrn28_2, wrn40_2, wrn28_10, resnet20, resnet56, resnext29_32x4d

from .unet import unet
示例#4
0
from homura import Registry

SEMI_TRAINER_REGISTRY = Registry('semi_trainer')

# The trainer modules are imported only after SEMI_TRAINER_REGISTRY exists,
# since they presumably register themselves into it on import (TODO confirm).
# Nesting the imports inside a block presumably keeps import sorters from
# moving them above the registry; their relative order is kept as-is.
if True:
    from .utils import unroll
    from .Ladder import Ladder
    from .MeanTeacher import MeanTeacher
    from .InterpolationConsistency import InterpolationConsistency
    from .AdversariallyLearnedInference import AdversariallyLearnedInference
    from .MixMatch import MixMatch
示例#5
0
from homura import Registry

SCHEDULER_REGISTRY = Registry('scheduler')

# Import the scheduler implementations only once SCHEDULER_REGISTRY is defined
# (they presumably register into it at import time -- TODO confirm). The block
# nesting presumably stops import sorters from hoisting this import.
if True:
    from .scheduler import Linear
示例#6
0
from homura import Registry
TRANSFORM_REGISTRY = Registry('transform')

# Import the transform modules only after TRANSFORM_REGISTRY exists (they
# presumably register into it at import time -- TODO confirm). The block
# nesting presumably stops import sorters from hoisting these imports; their
# order is unchanged.
if True:
    from .transform_many_times import *
    from .randaugment import RandAugment
示例#7
0
from homura import Registry

MODEL_REGISTRY = Registry('vision_model')

from torchvision import models

# Stock torchvision architectures exposed through the registry.
MODEL_REGISTRY.register(models.resnet18)
MODEL_REGISTRY.register(models.resnet50)
MODEL_REGISTRY.register(models.resnet101)
MODEL_REGISTRY.register(models.resnet152)
MODEL_REGISTRY.register(models.wide_resnet50_2)
MODEL_REGISTRY.register(models.densenet121)
MODEL_REGISTRY.register(models.vgg19_bn)
# EfficientNet-B0..B7 only exist in newer torchvision releases (torchvision
# 0.11, shipped alongside PyTorch 1.10 -- hence ">=1.10"). All eight variants
# were added together, so probing for efficientnet_b0 is enough.
if hasattr(models, "efficientnet_b0"):
    for _variant in range(8):
        MODEL_REGISTRY.register(getattr(models, f"efficientnet_b{_variant}"))
    # Drop the loop variable: a bare public name such as `v` would otherwise
    # linger as a module global and be re-exported by `from ... import *`.
    del _variant

from .densenet import densenet40, densenet100
from .cifar_resnet import wrn28_2, wrn40_2, wrn28_10, resnet20, resnet56, resnext29_32x4d

from .unet import unet
示例#8
0
        self.ln = nn.LayerNorm(emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        nn.init.zeros_(self.fc.weight)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch.

        Pipeline: patch embedding -> mixer blocks -> LayerNorm -> mean over
        dim 1 (the token axis, given the B x N x C layout of the patch
        embedding output) -> linear classification head.

        Args:
            input: batch passed to ``self.patch_emb``; presumably images
                shaped (B, C, H, W) -- TODO confirm.

        Returns:
            The output of ``self.fc`` applied to the token-averaged features.
        """
        tokens = self.patch_emb(input)  # B x N x C
        mixed = self.blocks(tokens)
        # Global average pooling over the N tokens after normalisation.
        pooled = self.ln(mixed).mean(dim=1)
        return self.fc(pooled)


from homura import Registry

# Registry (named "MLPMixer") that collects the MLP-Mixer factory functions
# defined below; each factory is added via the @MLPMixers.register decorator.
MLPMixers = Registry("MLPMixer")


@MLPMixers.register
def mixer_s32(num_classes, **kwargs) -> MLPMixer:
    """Build the "S/32" MLP-Mixer variant.

    Args:
        num_classes: number of output classes, passed as MLPMixer's first
            positional argument.
        **kwargs: extra keyword arguments forwarded unchanged to MLPMixer.

    Returns:
        A newly constructed MLPMixer.
    """
    # Positional hyper-parameters that follow num_classes. Presumably
    # (emb_dim, token-MLP dim, channel-MLP dim, patch size, depth), matching
    # Mixer-S/32 from the MLP-Mixer paper -- TODO confirm against
    # MLPMixer.__init__'s parameter order.
    s32_config = (512, 256, 2048, 32, 8)
    return MLPMixer(num_classes, *s32_config, **kwargs)


@MLPMixers.register
def mixer_s16(num_classes, **kwargs) -> MLPMixer:
    """Build the "S/16" MLP-Mixer variant.

    Identical to the S/32 configuration except for the fourth positional
    hyper-parameter (16 instead of 32).

    Args:
        num_classes: number of output classes, passed as MLPMixer's first
            positional argument.
        **kwargs: extra keyword arguments forwarded unchanged to MLPMixer.

    Returns:
        A newly constructed MLPMixer.
    """
    # Presumably (emb_dim, token-MLP dim, channel-MLP dim, patch size, depth)
    # -- TODO confirm against MLPMixer.__init__'s parameter order.
    s16_config = (512, 256, 2048, 16, 8)
    return MLPMixer(num_classes, *s16_config, **kwargs)


@MLPMixers.register
def mixer_b32(num_classes, **kwargs) -> MLPMixer:
    """Build the "B/32" MLP-Mixer variant.

    Args:
        num_classes: number of output classes, passed as MLPMixer's first
            positional argument.
        **kwargs: extra keyword arguments forwarded unchanged to MLPMixer.

    Returns:
        A newly constructed MLPMixer.
    """
    # Presumably (emb_dim, token-MLP dim, channel-MLP dim, patch size, depth),
    # matching Mixer-B/32 from the MLP-Mixer paper -- TODO confirm against
    # MLPMixer.__init__'s parameter order.
    b32_config = (768, 384, 3072, 32, 12)
    return MLPMixer(num_classes, *b32_config, **kwargs)