Example #1
def build_model(config):
    """Builds a ClassyModel from a config.

    This assumes a 'name' key in the config which is used to determine what
    model class to instantiate. For instance, a config `{"name": "my_model",
    "foo": "bar"}` will find a class that was registered as "my_model"
    (see :func:`register_model`) and call .from_config on it."""

    assert config["name"] in MODEL_REGISTRY, "unknown model"
    model = MODEL_REGISTRY[config["name"]].from_config(config)
    if "heads" in config:
        heads = defaultdict(list)
        for head_config in config["heads"]:
            assert "fork_block" in head_config, "Expect fork_block in config"
            fork_block = head_config["fork_block"]
            updated_config = copy.deepcopy(head_config)
            del updated_config["fork_block"]

            head = build_head(updated_config)
            heads[fork_block].append(head)
        model.set_heads(heads)

    log_class_usage("Model", model.__class__)

    return model
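A minimal sketch of how this builder might be called; "my_model", "my_head", and "block3" are placeholder names assumed to have been registered via register_model / register_head, not names taken from the library itself.

# hypothetical config; every name below is a placeholder
config = {
    "name": "my_model",
    "num_classes": 10,
    "heads": [
        {"name": "my_head", "fork_block": "block3", "num_classes": 10},
    ],
}
model = build_model(config)  # heads end up grouped by their fork_block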
Example #2
    def __init__(
        self,
        dataset: Sequence,
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Optional[Union[ClassyTransform, Callable]],
        num_samples: Optional[int],
    ) -> None:
        """
        Constructor for a ClassyDataset.

        Args:
            dataset: The underlying dataset being wrapped
            batchsize_per_replica: Positive integer indicating batch size for each
                replica
            shuffle: Whether to shuffle between epochs
            transform: When set, transform to be applied to each sample
            num_samples: When set, this restricts the number of samples provided by
                the dataset
        """
        # Asserts:
        assert is_pos_int(
            batchsize_per_replica
        ), "batchsize_per_replica must be a positive int"
        assert isinstance(shuffle, bool), "shuffle must be a boolean"
        assert num_samples is None or is_pos_int(
            num_samples
        ), "num_samples must be a positive int or None"

        # Assignments:
        self.batchsize_per_replica = batchsize_per_replica
        self.shuffle = shuffle
        self.transform = transform
        self.num_samples = num_samples
        self.dataset = dataset
        self.num_workers = DEFAULT_NUM_WORKERS

        log_class_usage("Dataset", self.__class__)
Example #3
    def __init__(self):
        """Constructor for ClassyModel."""
        super().__init__()
        self._attachable_blocks = {}
        self._attachable_block_names = []
        self._heads = nn.ModuleDict()
        self._head_outputs = {}

        log_class_usage("Model", self.__class__)
Example #4
    def __init__(self, update_interval: UpdateInterval):
        """
        Constructor for ClassyParamScheduler

        Args:
            update_interval: Specifies the frequency of the param updates
        """
        self.update_interval = update_interval
        log_class_usage("ParamScheduler", self.__class__)
Example #5
def build_task(config):
    """Builds a ClassyTask from a config.

    This assumes a 'name' key in the config which is used to determine what
    task class to instantiate. For instance, a config `{"name": "my_task",
    "foo": "bar"}` will find a class that was registered as "my_task"
    (see :func:`register_task`) and call .from_config on it."""

    task = TASK_REGISTRY[config["name"]].from_config(config)
    log_class_usage("Task", task.__class__)

    return task
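A hedged usage sketch; "my_task" and its keys are placeholders that would need to be registered via register_task first. Note that, unlike build_model above, this builder indexes TASK_REGISTRY directly, so an unknown name surfaces as a KeyError rather than an assertion message.

# hypothetical config; "my_task" is a placeholder registered elsewhere
task = build_task({"name": "my_task", "num_epochs": 2})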
Example #6
    def __init__(self) -> None:
        """Constructor for ClassyOptimizer.

        :var options_view: provides convenient access to current values of
            learning rate, momentum etc.
        :var _param_group_schedulers: list of dictionaries in the param_groups
            format, containing all ParamScheduler instances needed. Constant
            values are converted to ConstantParamScheduler before being inserted
            here.
        """
        self.options_view = OptionsView(self)
        self.optimizer = None
        self._param_group_schedulers = None
        log_class_usage("Optimizer", self.__class__)
Example #7
def build_transform(transform_config: Dict[str, Any]) -> Callable:
    """Builds a :class:`ClassyTransform` from a config.

    This assumes a 'name' key in the config which is used to determine what
    transform class to instantiate. For instance, a config `{"name":
    "my_transform", "foo": "bar"}` will find a class that was registered as
    "my_transform" (see :func:`register_transform`) and call .from_config on
    it.

    In addition to transforms registered with :func:`register_transform`, we
    also support instantiating transforms available in the
    `torchvision.transforms <https://pytorch.org/docs/stable/torchvision/
    transforms.html>`_ module. Any keys in the config will get expanded
    to parameters of the transform constructor. For instance, the following
    call will instantiate a :class:`torchvision.transforms.CenterCrop`:

    .. code-block:: python

      build_transform({"name": "CenterCrop", "size": 224})
    """
    assert (
        "name" in transform_config
    ), f"name not provided for transform: {transform_config}"
    name = transform_config["name"]

    transform_args = {k: v for k, v in transform_config.items() if k != "name"}

    if name in TRANSFORM_REGISTRY:
        transform = TRANSFORM_REGISTRY[name].from_config(transform_args)
    else:
        # the name should be available in torchvision.transforms;
        # if the user specified a torchvision transform name in snake_case,
        # convert it to the CamelCase class name torchvision expects
        if not (hasattr(transforms, name) or hasattr(transforms_video, name)):
            name = name.title().replace("_", "")
        assert hasattr(transforms, name) or hasattr(transforms_video, name), (
            f"{name} isn't a registered tranform"
            ", nor is it available in torchvision.transforms"
        )
        if hasattr(transforms, name):
            transform = getattr(transforms, name)(**transform_args)
        else:
            transform = getattr(transforms_video, name)(**transform_args)
    log_class_usage("Transform", transform.__class__)
    return transform
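Both calls below should resolve to torchvision.transforms.CenterCrop, the second one via the snake_case-to-CamelCase fallback in the code above (a sketch assuming torchvision is installed and "center_crop" is not separately registered).

center_crop = build_transform({"name": "CenterCrop", "size": 224})
same_crop = build_transform({"name": "center_crop", "size": 224})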
Example #8
def build_loss(config):
    """Builds a ClassyLoss from a config.

    This assumes a 'name' key in the config which is used to determine what
    loss class to instantiate. For instance, a config `{"name": "my_loss",
    "foo": "bar"}` will find a class that was registered as "my_loss"
    (see :func:`register_loss`) and call .from_config on it.

    In addition to losses registered with :func:`register_loss`, we also
    support instantiating losses available in the `torch.nn.modules.loss <https:
    //pytorch.org/docs/stable/nn.html#loss-functions>`_
    module. Any keys in the config will get expanded to parameters of the loss
    constructor. For instance, the following call will instantiate a
    `torch.nn.modules.CrossEntropyLoss <https://pytorch.org/docs/stable/
    nn.html#torch.nn.CrossEntropyLoss>`_:

    .. code-block:: python

     build_loss({"name": "CrossEntropyLoss", "reduction": "sum"})
    """

    assert "name" in config, f"name not provided for loss: {config}"
    name = config["name"]
    args = copy.deepcopy(config)
    del args["name"]
    if "weight" in args and args["weight"] is not None:
        # if we are passing weights, we need to change the weights from a list
        # to a tensor
        args["weight"] = torch.tensor(args["weight"], dtype=torch.float)
    if name in LOSS_REGISTRY:
        loss = LOSS_REGISTRY[name].from_config(config)
    else:
        # the name should be available in torch.nn.modules.loss
        assert hasattr(torch_losses, name), (
            f"{name} isn't a registered loss"
            ", nor is it available in torch.nn.modules.loss"
        )
        loss = getattr(torch_losses, name)(**args)
    log_class_usage("Loss", loss.__class__)
    return loss
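A short sketch of the weight-handling path shown above: per-class weights arrive as a plain list in the config and are converted to a float tensor before the torch loss is constructed.

loss = build_loss(
    {"name": "CrossEntropyLoss", "weight": [1.0, 2.0, 0.5], "reduction": "sum"}
)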
Example #9
    def __init__(self):
        log_class_usage("Hooks", self.__class__)
        self.state = ClassyHookState()
Example #10
    def __init__(self):
        log_class_usage("Meter", self.__class__)
Example #11
 def __init__(self) -> "ClassyTask":
     """
     Constructs a ClassyTask.
     """
     self.hooks = []
     log_class_usage("Task", self.__class__)