def build_model(config):
    """Build a ClassyModel from a config dictionary.

    The ``"name"`` entry of the config selects which model class to
    instantiate. For instance, a config `{"name": "my_model", "foo": "bar"}`
    resolves the class that was registered as "my_model" (see
    :func:`register_model`) and calls ``.from_config`` on it with the full
    config.

    When the config has a ``"heads"`` entry, every head config in it must
    provide a ``"fork_block"`` key. Each head is built with ``build_head``
    from a deep copy of its config with ``"fork_block"`` removed, the heads
    are grouped by their fork block, and the grouping is passed to
    ``model.set_heads``.

    Args:
        config: Mapping describing the model; must contain ``"name"``.

    Returns:
        The model produced by the registered class's ``from_config``.

    Raises:
        AssertionError: If the name is not in ``MODEL_REGISTRY`` or a head
            config lacks ``"fork_block"``.
    """
    model_name = config["name"]
    assert model_name in MODEL_REGISTRY, "unknown model"
    model = MODEL_REGISTRY[model_name].from_config(config)

    if "heads" in config:
        heads_by_block = defaultdict(list)
        for head_spec in config["heads"]:
            assert "fork_block" in head_spec, "Expect fork_block in config"
            block_name = head_spec["fork_block"]
            # build_head must not see "fork_block"; strip it from a copy so
            # the caller's config is left untouched.
            head_args = copy.deepcopy(head_spec)
            head_args.pop("fork_block")
            heads_by_block[block_name].append(build_head(head_args))
        model.set_heads(heads_by_block)

    log_class_usage("Model", model.__class__)
    return model
def __init__(
    self,
    dataset: Sequence,
    batchsize_per_replica: int,
    shuffle: bool,
    transform: Optional[Union[ClassyTransform, Callable]],
    num_samples: Optional[int],
) -> None:
    """
    Constructor for a ClassyDataset.

    Args:
        dataset: Sequence stored on the instance as ``self.dataset``
        batchsize_per_replica: Positive integer indicating batch size for
            each replica
        shuffle: Whether to shuffle between epochs
        transform: When set, transform to be applied to each sample
        num_samples: When set, this restricts the number of samples
            provided by the dataset

    Raises:
        AssertionError: If ``batchsize_per_replica`` is not a positive int,
            ``shuffle`` is not a bool, or ``num_samples`` is neither None
            nor a positive int.
    """
    # Argument validation, checked in the same order as the parameters.
    assert is_pos_int(
        batchsize_per_replica
    ), "batchsize_per_replica must be a positive int"
    assert isinstance(shuffle, bool), "shuffle must be a boolean"
    if num_samples is not None:
        assert is_pos_int(
            num_samples
        ), "num_samples must be a positive int or None"

    # Store the validated configuration on the instance.
    self.batchsize_per_replica = batchsize_per_replica
    self.shuffle = shuffle
    self.transform = transform
    self.num_samples = num_samples
    self.dataset = dataset
    # Worker count starts at the module default.
    self.num_workers = DEFAULT_NUM_WORKERS

    log_class_usage("Dataset", self.__class__)
def __init__(self):
    """Constructor for ClassyModel.

    Initializes the parent class, then sets up empty containers for the
    attachable blocks, their names, the heads (an ``nn.ModuleDict``) and the
    head outputs, and finally reports the concrete class through
    ``log_class_usage``.
    """
    # The parent initializer runs first, before any attribute is assigned.
    super().__init__()

    self._attachable_blocks = dict()
    self._attachable_block_names = list()
    self._heads = nn.ModuleDict()
    self._head_outputs = dict()

    model_cls = self.__class__
    log_class_usage("Model", model_cls)
def __init__(self, update_interval: UpdateInterval):
    """
    Constructor for ClassyParamScheduler

    Stores the update interval on the instance and reports the concrete
    class through ``log_class_usage``.

    Args:
        update_interval: Specifies the frequency of the param updates
    """
    self.update_interval = update_interval

    scheduler_cls = self.__class__
    log_class_usage("ParamScheduler", scheduler_cls)
def build_task(config):
    """Build a ClassyTask from a config dictionary.

    The ``"name"`` entry of the config selects which task class to
    instantiate. For instance, a config `{"name": "my_task", "foo": "bar"}`
    resolves the class that was registered as "my_task" (see
    :func:`register_task`) and calls ``.from_config`` on it with the full
    config.

    Args:
        config: Mapping describing the task; must contain ``"name"``.

    Returns:
        The task produced by the registered class's ``from_config``.

    Raises:
        KeyError: If ``"name"`` is missing from the config or is not a key
            of ``TASK_REGISTRY``.
    """
    task_cls = TASK_REGISTRY[config["name"]]
    built_task = task_cls.from_config(config)
    log_class_usage("Task", built_task.__class__)
    return built_task
def __init__(self) -> None:
    """Constructor for ClassyOptimizer.

    The wrapped optimizer and the param group schedulers both start out as
    ``None``.

    :var options_view: provides convenient access to current values of
        learning rate, momentum etc.
    :var _param_group_schedulers: list of dictionaries in the param_groups
        format, containing all ParamScheduler instances needed. Constant
        values are converted to ConstantParamScheduler before being
        inserted here.
    """
    # The view is handed this instance, so it is created before anything else.
    view = OptionsView(self)
    self.options_view = view

    # Neither of these is populated at construction time.
    self.optimizer = None
    self._param_group_schedulers = None

    optimizer_cls = self.__class__
    log_class_usage("Optimizer", optimizer_cls)
def build_transform(transform_config: Dict[str, Any]) -> Callable:
    """Builds a :class:`ClassyTransform` from a config.

    The ``"name"`` entry of the config selects which transform class to
    instantiate. For instance, a config `{"name": "my_transform", "foo": "bar"}`
    resolves the class that was registered as "my_transform" (see
    :func:`register_transform`) and calls ``.from_config`` on it, passing the
    config without its ``"name"`` entry.

    In addition to transforms registered with :func:`register_transform`,
    transforms available in the `torchvision.transforms
    <https://pytorch.org/docs/stable/torchvision/transforms.html>`_ module
    (and the ``transforms_video`` module) can be instantiated. The remaining
    keys in the config are expanded to keyword arguments of the transform
    constructor. For instance, the following call will instantiate a
    :class:`torchvision.transforms.CenterCrop`:

    .. code-block:: python

      build_transform({"name": "CenterCrop", "size": 224})

    Args:
        transform_config: Config containing ``"name"`` plus the transform's
            own arguments.

    Returns:
        The instantiated transform.

    Raises:
        AssertionError: If ``"name"`` is missing, or the name is neither
            registered nor found in the torchvision modules.
    """
    assert (
        "name" in transform_config
    ), f"name not provided for transform: {transform_config}"
    name = transform_config["name"]
    # Everything except the name is forwarded to the transform itself.
    kwargs = {
        key: value for key, value in transform_config.items() if key != "name"
    }

    if name not in TRANSFORM_REGISTRY:
        # Not registered, so the name has to exist in the torchvision
        # modules. A name that is found in neither module as given (e.g. a
        # snake_case spelling) is retried after title-casing each word and
        # dropping the underscores.
        if not hasattr(transforms, name) and not hasattr(transforms_video, name):
            name = name.title().replace("_", "")
        assert hasattr(transforms, name) or hasattr(transforms_video, name), (
            f"{name} isn't a registered tranform"
            ", nor is it available in torchvision.transforms"
        )
        # transforms takes precedence over transforms_video.
        source_module = transforms if hasattr(transforms, name) else transforms_video
        transform = getattr(source_module, name)(**kwargs)
    else:
        transform = TRANSFORM_REGISTRY[name].from_config(kwargs)

    log_class_usage("Transform", transform.__class__)
    return transform
def build_loss(config):
    """Builds a ClassyLoss from a config.

    The ``"name"`` entry of the config selects which loss class to
    instantiate. For instance, a config `{"name": "my_loss", "foo": "bar"}`
    resolves the class that was registered as "my_loss" (see
    :func:`register_loss`) and calls ``.from_config`` on it with the
    original, unmodified config.

    In addition to losses registered with :func:`register_loss`, losses
    available in the `torch.nn.modules.loss
    <https://pytorch.org/docs/stable/nn.html#loss-functions>`_ module can be
    instantiated. In that case the config keys other than ``"name"`` are
    expanded to keyword arguments of the loss constructor, and a non-None
    ``"weight"`` entry is converted to a float tensor first. For instance,
    the following call will instantiate a `torch.nn.modules.CrossEntropyLoss
    <https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss>`_:

    .. code-block:: python

      build_loss({"name": "CrossEntropyLoss", "reduction": "sum"})

    Args:
        config: Mapping describing the loss; must contain ``"name"``.

    Returns:
        The instantiated loss.

    Raises:
        AssertionError: If ``"name"`` is missing, or the name is neither
            registered nor found in ``torch.nn.modules.loss``.
    """
    assert "name" in config, f"name not provided for loss: {config}"
    loss_name = config["name"]

    # Constructor kwargs for torch losses: a private copy of the config
    # without the name, so the caller's config is never mutated.
    kwargs = copy.deepcopy(config)
    kwargs.pop("name")
    if kwargs.get("weight") is not None:
        # Weights arrive as a plain list; the torch losses need a tensor.
        kwargs["weight"] = torch.tensor(kwargs["weight"], dtype=torch.float)

    if loss_name not in LOSS_REGISTRY:
        # The name should be available in torch.nn.modules.loss.
        assert hasattr(torch_losses, loss_name), (
            f"{loss_name} isn't a registered loss"
            ", nor is it available in torch.nn.modules.loss"
        )
        loss = getattr(torch_losses, loss_name)(**kwargs)
    else:
        # Registered losses receive the raw config, not the processed kwargs.
        loss = LOSS_REGISTRY[loss_name].from_config(config)

    log_class_usage("Loss", loss.__class__)
    return loss
def __init__(self):
    """Report the concrete hook class and create the hook's state.

    Calls ``log_class_usage`` with ``"Hooks"`` and the instance's class, then
    stores a new ``ClassyHookState`` on ``self.state``.
    """
    hook_cls = self.__class__
    log_class_usage("Hooks", hook_cls)

    fresh_state = ClassyHookState()
    self.state = fresh_state
def __init__(self):
    """Report the concrete meter class.

    Calls ``log_class_usage`` with ``"Meter"`` and the instance's class; no
    instance attributes are set here.
    """
    meter_cls = self.__class__
    log_class_usage("Meter", meter_cls)
def __init__(self) -> None:
    """
    Constructs a ClassyTask.

    The task starts with an empty ``hooks`` list, and the concrete class is
    reported through ``log_class_usage``.
    """
    # NOTE: the return annotation is None rather than "ClassyTask" because
    # __init__ only initializes the instance and never returns it.
    self.hooks = []
    log_class_usage("Task", self.__class__)