    head class to instantiate. For instance, a config `{"name": "my_head", "foo": "bar"}`
    will find a class that was registered as "my_head" (see :func:`register_head`)
    and call .from_config on it."""
    assert "name" in config, "Expect name in config"
    assert "unique_id" in config, "Expect a global unique id in config"
    assert config["name"] in HEAD_REGISTRY, "unknown head {}".format(config["name"])
    name = config["name"]
    head_config = copy.deepcopy(config)
    del head_config["name"]
    return HEAD_REGISTRY[name].from_config(head_config)


# automatically import any Python files in the heads/ directory
import_all_modules(FILE_ROOT, "classy_vision.heads")

from .fully_connected_head import FullyConnectedHead  # isort:skip
from .fully_convolutional_linear_head import FullyConvolutionalLinearHead  # isort:skip
from .identity_head import IdentityHead  # isort:skip
from .vision_transformer_head import VisionTransformerHead  # isort:skip

__all__ = [
    "ClassyHead",
    "FullyConnectedHead",
    "FullyConvolutionalLinearHead",
    "IdentityHead",
    "VisionTransformerHead",
    "build_head",
    "register_head",
]
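# Example (illustrative sketch, not part of the file above): building a head
# from a config dict. This assumes FullyConnectedHead is registered under the
# name "fully_connected" and that the remaining keys match its from_config
# signature; adjust them to the head you actually use.
from classy_vision.heads import build_head

head = build_head(
    {
        "name": "fully_connected",    # registry key chosen via @register_head
        "unique_id": "default_head",  # required globally unique id
        "num_classes": 1000,          # assumed FullyConnectedHead parameters
        "in_plane": 2048,
    }
)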
        return cls

    return register_model_head_cls


def get_model_head(name: str):
    """
    Given the model head name, construct the head if it's registered with VISSL.
    """
    assert name in MODEL_HEADS_REGISTRY, "Unknown model head"
    return MODEL_HEADS_REGISTRY[name]


# automatically import any Python files in the heads/ directory
import_all_modules(FILE_ROOT, "vissl.models.heads")

from vissl.models.heads.linear_eval_mlp import LinearEvalMLP  # isort:skip # noqa
from vissl.models.heads.mlp import MLP  # isort:skip # noqa
from vissl.models.heads.siamese_concat_view import (  # isort:skip # noqa
    SiameseConcatView,
)
from vissl.models.heads.swav_prototypes_head import (  # isort:skip # noqa
    SwAVPrototypesHead,
)

__all__ = [
    "get_model_head",
    "LinearEvalMLP",
    "MLP",
    "SiameseConcatView",
    "SwAVPrototypesHead",
]
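# Example (illustrative sketch): looking up a registered VISSL head class by
# name. "mlp" is assumed to be the registry key used by the MLP head above;
# get_model_head returns the class itself, which VISSL instantiates later
# from the model config.
from vissl.models.heads import get_model_head

mlp_head_cls = get_model_head("mlp")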
    def register_loss_cls(cls):
        if name in LOSS_REGISTRY:
            raise ValueError("Cannot register duplicate loss ({})".format(name))
        if not issubclass(cls, ClassyLoss):
            raise ValueError(
                "Loss ({}: {}) must extend ClassyLoss".format(name, cls.__name__)
            )
        LOSS_REGISTRY[name] = cls
        LOSS_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_loss_cls


# automatically import any Python files in the losses/ directory
import_all_modules(FILE_ROOT, "classy_vision.losses")

from .barron_loss import BarronLoss  # isort:skip
from .label_smoothing_loss import LabelSmoothingCrossEntropyLoss  # isort:skip
from .multi_output_sum_loss import MultiOutputSumLoss  # isort:skip
from .soft_target_cross_entropy_loss import SoftTargetCrossEntropyLoss  # isort:skip
from .sum_arbitrary_loss import SumArbitraryLoss  # isort:skip

__all__ = [
    "BarronLoss",
    "ClassyLoss",
    "LabelSmoothingCrossEntropyLoss",
    "MultiOutputSumLoss",
    "SoftTargetCrossEntropyLoss",
    "SumArbitraryLoss",
    "build_loss",
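# Example (illustrative sketch): registering a custom loss so build_loss can
# find it by name. "my_mse_loss" and the forward logic are placeholders; the
# class must extend ClassyLoss, as enforced by the check above.
from classy_vision.losses import ClassyLoss, register_loss


@register_loss("my_mse_loss")
class MyMSELoss(ClassyLoss):
    @classmethod
    def from_config(cls, config):
        return cls()

    def forward(self, output, target):
        return (output - target).pow(2).mean()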
            )
        if cls.__name__ in OPTIMIZER_CLASS_NAMES:
            raise ValueError(
                "Cannot register optimizer with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        OPTIMIZER_REGISTRY[name] = cls
        OPTIMIZER_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_optimizer_cls


# automatically import any Python files in the optim/ directory
import_all_modules(FILE_ROOT, "classy_vision.optim")

from .adam import Adam  # isort:skip
from .rmsprop import RMSProp  # isort:skip
from .rmsprop_tf import RMSPropTF  # isort:skip
from .sgd import SGD  # isort:skip

__all__ = [
    "Adam",
    "ClassyOptimizer",
    "RMSProp",
    "RMSPropTF",
    "SGD",
    "build_optimizer",
    "register_optimizer",
]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# Automatically import any Python files in the losses/ directory
import_all_modules(FILE_ROOT, "losses")
    ...
    To get a model trunk from a configuration file, see :func:`get_model_trunk`."""

    def register_model_trunk_cls(cls: Callable[..., Callable]):
        if name in MODEL_TRUNKS_REGISTRY:
            raise ValueError("Cannot register duplicate model trunk ({})".format(name))
        if cls.__name__ in MODEL_TRUNKS_NAMES:
            raise ValueError(
                "Cannot register model trunk with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        MODEL_TRUNKS_REGISTRY[name] = cls
        MODEL_TRUNKS_NAMES.add(cls.__name__)
        return cls

    return register_model_trunk_cls


def get_model_trunk(name: str):
    """
    Given the model trunk name, construct the trunk if it's registered with VISSL.
    """
    assert name in MODEL_TRUNKS_REGISTRY, "Unknown model trunk"
    return MODEL_TRUNKS_REGISTRY[name]


# automatically import any Python files in the trunks/ directory
import_all_modules(FILE_ROOT, "vissl.models.trunks")
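# Example (illustrative sketch): fetching a registered trunk class by name.
# "resnet" is assumed to be a trunk key registered elsewhere in VISSL; the
# lookup returns the class itself, which VISSL instantiates with the model
# config.
from vissl.models.trunks import get_model_trunk

trunk_cls = get_model_trunk("resnet")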
raise ValueError("Task ({}: {}) must extend ClassyTask".format( name, cls.__name__)) if cls.__name__ in TASK_CLASS_NAMES: msg = ("Cannot register task with duplicate class name({})." + "Previously registered at \n{}\n") raise ValueError( msg.format(cls.__name__, TASK_CLASS_NAMES_TB[cls.__name__])) tb = "".join(traceback.format_stack()) TASK_REGISTRY[name] = cls TASK_CLASS_NAMES.add(cls.__name__) TASK_REGISTRY_TB[name] = tb TASK_CLASS_NAMES_TB[cls.__name__] = tb return cls return register_task_cls from .classification_task import ClassificationTask # isort:skip from .fine_tuning_task import FineTuningTask # isort:skip __all__ = [ "ClassyTask", "FineTuningTask", "build_task", "register_task", "ClassificationTask", ] # automatically import any Python files in the tasks/ directory import_all_modules(FILE_ROOT, "classy_vision.tasks")
        if hasattr(transforms, name) or hasattr(transforms_video, name):
            raise ValueError(
                "{} already exists in torchvision.transforms, please change the name!".format(
                    name
                )
            )
        TRANSFORM_REGISTRY[name] = cls
        tb = "".join(traceback.format_stack())
        TRANSFORM_REGISTRY_TB[name] = tb
        return cls

    return register_transform_cls


# automatically import any Python files in the transforms/ directory
import_all_modules(FILE_ROOT, "classy_vision.dataset.transforms")

from .lighting_transform import LightingTransform  # isort:skip
from .util import ApplyTransformToKey  # isort:skip
from .util import ImagenetAugmentTransform  # isort:skip
from .util import ImagenetNoAugmentTransform  # isort:skip
from .util import GenericImageTransform  # isort:skip
from .util import TupleToMapTransform  # isort:skip

__all__ = [
    "ClassyTransform",
    "ImagenetAugmentTransform",
    "ImagenetNoAugmentTransform",
    "GenericImageTransform",
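# Example (illustrative sketch): registering a custom transform. "my_noise" is
# a placeholder name; it must not clash with an existing torchvision.transforms
# name, per the check above, and the class implements __call__ like any other
# ClassyTransform.
from classy_vision.dataset.transforms import ClassyTransform, register_transform


@register_transform("my_noise")
class MyNoiseTransform(ClassyTransform):
    def __call__(self, image):
        # placeholder: a real transform would perturb the input here
        return image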
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# automatically import any Python files in the meters/ directory
import_all_modules(FILE_ROOT, "vissl.meters")
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# automatically import any Python files in the param_scheduler/ directory
import_all_modules(FILE_ROOT, "vissl.optimizers.param_scheduler")
"Previously registered at \n{}\n") raise ValueError( msg.format(cls.__name__, DATASET_CLASS_NAMES_TB[cls.__name__])) tb = "".join(traceback.format_stack()) DATASET_REGISTRY[name] = cls DATASET_CLASS_NAMES.add(cls.__name__) DATASET_REGISTRY_TB[name] = tb DATASET_CLASS_NAMES_TB[cls.__name__] = tb return cls return register_dataset_cls # automatically import any Python files in the dataset/ directory import_all_modules(FILE_ROOT, "classy_vision.dataset") from .classy_cifar import CIFARDataset # isort:skip from .classy_hmdb51 import HMDB51Dataset # isort:skip from .classy_kinetics400 import Kinetics400Dataset # isort:skip from .classy_synthetic_image import SyntheticImageDataset # isort:skip from .classy_synthetic_image_streaming import ( # isort:skip SyntheticImageStreamingDataset, # isort:skip ) # isort:skip from .classy_synthetic_video import SyntheticVideoDataset # isort:skip from .classy_ucf101 import UCF101Dataset # isort:skip from .classy_video_dataset import ClassyVideoDataset # isort:skip from .dataloader_async_gpu_wrapper import DataloaderAsyncGPUWrapper # isort:skip from .dataloader_limit_wrapper import DataloaderLimitWrapper # isort:skip from .dataloader_skip_none_wrapper import DataloaderSkipNoneWrapper # isort:skip from .dataloader_wrapper import DataloaderWrapper # isort:skip
    def register_param_scheduler_cls(cls):
        if name in PARAM_SCHEDULER_REGISTRY:
            raise ValueError("Cannot register duplicate param scheduler ({})".format(name))
        if not issubclass(cls, ParamScheduler):
            raise ValueError(
                "Param Scheduler ({}: {}) must extend ParamScheduler".format(
                    name, cls.__name__
                )
            )
        PARAM_SCHEDULER_REGISTRY[name] = cls
        return cls

    return register_param_scheduler_cls


# automatically import any Python files in the optim/param_scheduler/ directory
import_all_modules(FILE_ROOT, "classy_vision.optim.param_scheduler")

from .composite_scheduler import CompositeParamScheduler, IntervalScaling  # isort:skip
from .fvcore_schedulers import (
    ConstantParamScheduler,
    CosineParamScheduler,
    LinearParamScheduler,
    MultiStepParamScheduler,
    PolynomialDecayParamScheduler,
    StepParamScheduler,
    StepWithFixedGammaParamScheduler,
)  # isort:skip

__all__ = [
    "ParamScheduler",
    "ClassyParamScheduler",
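# Example (illustrative sketch): registering a custom scheduler. "my_warmup"
# is a placeholder name; the class extends ClassyParamScheduler (and therefore
# ParamScheduler) so it passes the check above, and __call__(where) returns
# the parameter value at the given fraction of training.
from classy_vision.optim.param_scheduler import (
    ClassyParamScheduler,
    register_param_scheduler,
)


@register_param_scheduler("my_warmup")
class MyWarmupScheduler(ClassyParamScheduler):
    def __call__(self, where: float) -> float:
        # linearly warm up to 0.1 over the first 5% of training (assumed values)
        return 0.1 * min(1.0, where / 0.05)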
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# automatically import any Python files in the losses/ directory
import_all_modules(FILE_ROOT, "vissl.losses")
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import torchvision.transforms as pth_transforms
from classy_vision.generic.registry_utils import import_all_modules
from vissl.data.ssl_transforms.ssl_transforms_wrapper import SSLTransformsWrapper


def get_transform(input_transforms_list):
    """
    Given the list of user-specified transforms, return the
    torchvision.transforms.Compose() version of the transforms. Each transform
    in the composition is an SSLTransformsWrapper, which wraps the original
    transform to handle the multi-modal nature of the input.
    """
    output_transforms = []
    for transform_config in input_transforms_list:
        transform = SSLTransformsWrapper.from_config(transform_config)
        output_transforms.append(transform)
    return pth_transforms.Compose(output_transforms)


FILE_ROOT = Path(__file__).parent

import_all_modules(FILE_ROOT, "vissl.data.ssl_transforms")

__all__ = ["get_transform"]
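# Example (illustrative sketch): composing a transform pipeline from a list of
# config dicts. The transform names and parameters below are assumptions; each
# dict needs a "name" key that SSLTransformsWrapper.from_config can resolve,
# whether it refers to a VISSL-registered transform or a torchvision one.
from vissl.data.ssl_transforms import get_transform

transform = get_transform(
    [
        {"name": "RandomResizedCrop", "size": 224},
        {"name": "RandomHorizontalFlip", "p": 0.5},
        {"name": "ToTensor"},
    ]
)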
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# Automatically import any Python files in the datasets/ directory
import_all_modules(FILE_ROOT, "datasets")
    def register_collator_fn(func):
        if name in COLLATOR_REGISTRY:
            raise ValueError("Cannot register duplicate collator ({})".format(name))
        if func.__name__ in COLLATOR_NAMES:
            raise ValueError(
                "Cannot register collator with duplicate name ({})".format(func.__name__)
            )
        COLLATOR_REGISTRY[name] = func
        COLLATOR_NAMES.add(func.__name__)
        return func

    return register_collator_fn


def get_collator(collator_name, collate_params):
    """
    Given the collator name and the collator params, return the collator
    if it is registered with VISSL. Also supports the default PyTorch collator.
    """
    if collator_name == "default_collate":
        return default_collate
    else:
        assert collator_name in COLLATOR_REGISTRY, "Unknown collator"
        return partial(COLLATOR_REGISTRY[collator_name], **collate_params)


# automatically import any Python files in the collators/ directory
import_all_modules(FILE_ROOT, "vissl.data.collators")
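# Example (illustrative sketch): resolving collators by name. The PyTorch
# default collator takes no extra params; "siamese_collator" is assumed to be
# a collator registered elsewhere in VISSL, with collate_params bound via
# functools.partial as in the code above.
from vissl.data.collators import get_collator

default_fn = get_collator("default_collate", collate_params={})
siamese_fn = get_collator("siamese_collator", collate_params={})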
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules


FILE_ROOT = Path(__file__).parent

# Automatically import any Python files in the models/ directory
import_all_modules(FILE_ROOT, "models")
if "heads" in config: heads = defaultdict(list) for head_config in config["heads"]: assert "fork_block" in head_config, "Expect fork_block in config" fork_block = head_config["fork_block"] updated_config = copy.deepcopy(head_config) del updated_config["fork_block"] head = build_head(updated_config) heads[fork_block].append(head) model.set_heads(heads) return model # automatically import any Python files in the models/ directory import_all_modules(FILE_ROOT, "classy_vision.models") from .classy_block import ClassyBlock # isort:skip from .classy_model import ( # isort:skip ClassyModelWrapper, # isort:skip ClassyModelHeadExecutorWrapper, # isort:skip ) # isort:skip from .densenet import DenseNet # isort:skip from .efficientnet import EfficientNet # isort:skip from .mlp import MLP # isort:skip from .regnet import RegNet # isort:skip from .resnet import ResNet # isort:skip from .resnext import ResNeXt # isort:skip from .resnext3d import ResNeXt3D # isort:skip from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer # isort:skip
    To get a train step from a configuration file, see :func:`get_train_step`.
    """

    def register_train_step_fn(func):
        if name in TRAIN_STEP_REGISTRY:
            raise ValueError("Cannot register duplicate train step ({})".format(name))
        if func.__name__ in TRAIN_STEP_NAMES:
            raise ValueError(
                "Cannot register train step with duplicate name ({})".format(
                    func.__name__
                )
            )
        TRAIN_STEP_REGISTRY[name] = func
        TRAIN_STEP_NAMES.add(func.__name__)
        return func

    return register_train_step_fn


def get_train_step(train_step_name: str):
    """
    Look up train_step_name in the train step registry and return the
    registered train step function. If the train step is not registered, the
    assertion fails and the workflow exits.
    """
    assert train_step_name in TRAIN_STEP_REGISTRY, "Unknown train step"
    return TRAIN_STEP_REGISTRY[train_step_name]


# automatically import any Python files in the train_steps/ directory
import_all_modules(FILE_ROOT, "vissl.trainer.train_steps")
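# Example (illustrative sketch): looking up a train step by name.
# "standard_train_step" is assumed to be the name under which VISSL's default
# train step is registered; the lookup returns the function itself, not a call.
from vissl.trainer.train_steps import get_train_step

train_step_fn = get_train_step("standard_train_step")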