Example #1
from typing import Callable, Optional, Type, get_type_hints

from simple_parsing.helpers.serialization import register_decoding_fn


def register(fn_or_type: Optional[Type] = None) -> Callable:
    # NOTE: The enclosing signature is inferred from the snippet's use of
    # `fn_or_type`; the original excerpt only showed the inner `_wrapper`.
    def _wrapper(fn):
        if fn_or_type is not None:
            type_ = fn_or_type
        else:
            type_hints = get_type_hints(fn)
            if "return" not in type_hints:
                raise RuntimeError(
                    "Need to either explicitly pass a type to `register`, or use "
                    "a return type annotation (e.g. `-> Foo:`) on the function!"
                )
            type_ = type_hints["return"]
        register_decoding_fn(type_, fn)
        return fn

    return _wrapper
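
A minimal usage sketch, assuming the inferred `register()` signature above (the `Foo` dataclass and the decoder names are illustrative, not from the original source):

from dataclasses import dataclass

@dataclass
class Foo:
    x: int = 0

# The target type is read from the `-> Foo` return annotation:
@register()
def decode_foo(raw: dict) -> Foo:
    return Foo(**raw)

# Equivalent, passing the type explicitly instead:
# @register(Foo)
# def decode_foo(raw: dict): ...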
Example #2

from dataclasses import dataclass, field
from typing import List

import torch
from torch import Tensor

from simple_parsing.helpers import Serializable, encode
from simple_parsing.helpers.serialization import register_decoding_fn


# NOTE: `Person` is not shown in the original excerpt; it is reconstructed
# here from the expected YAML output below.
@dataclass
class Person(Serializable):
    name: str = "Bob"
    age: int = 20
    t: Tensor = field(default_factory=lambda: torch.arange(4))


@dataclass
class Student(Person):
    domain: str = "Computer Science"
    average_grade: float = 0.80


@encode.register
def encode_tensor(obj: Tensor) -> List:
    """ We choose to encode a tensor as a list, for instance """
    return obj.tolist()


# We will use `torch.as_tensor` as our decoding function
register_decoding_fn(Tensor, torch.as_tensor)

# Serialization:
# We can dump to yaml or json:
charlie = Person(name="Charlie")
print(charlie.dumps_yaml())
# The YAML output we expect to see:
expected = """\
age: 20
name: Charlie
t:
- 0
- 1
- 2
- 3

"""
Example #3
    @property
    def batch_size(self) -> int:
        return self.hp.batch_size

    @batch_size.setter
    def batch_size(self, value: int) -> None:
        self.hp.batch_size = value

    @property
    def learning_rate(self) -> float:
        return self.hp.learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float) -> None:
        self.hp.learning_rate = value

    def on_task_switch(self, task_id: Optional[int]) -> None:
        """Called when switching between tasks. This is a no-op by default.

        Args:
            task_id (Optional[int]): The ID of the new task, or None.
        """

    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
        model_summary = ModelSummary(self, mode=mode)
        log.debug('\n' + str(model_summary))
        return model_summary

from simple_parsing.helpers.serialization import register_decoding_fn

# Identity "decoder": fields annotated with `Type[OutputHead]` are passed
# through as-is rather than being parsed from a serialized form.
register_decoding_fn(Type[OutputHead], lambda v: v)
Example #4
from typing import Any, TypeVar

import numpy as np
import torch
from torch import Tensor, nn

from simple_parsing.helpers import Serializable as SerializableBase
from simple_parsing.helpers import SimpleJsonEncoder
from simple_parsing.helpers.serialization import register_decoding_fn

from sequoia.utils.generic_functions import detach

from .encode import encode
from .generic_functions.move import move
from .logging_utils import get_logger
from .utils import dict_union

# `torch.device` doubles as its own decoding function: it accepts the
# serialized string form (e.g. "cuda:0") and returns a device object.
register_decoding_fn(torch.device, torch.device)

T = TypeVar("T")
logger = get_logger(__file__)


def cpu(x: Any) -> Any:
    return move(x, "cpu")


class Pickleable:
    """ Helps make a class pickleable. """
    def __getstate__(self):
        """ Implemented just to make sure any tensors are detached
        before pickling.
        """
Example #5
from abc import ABC
from dataclasses import dataclass
from typing import Any, Callable, Union

# NOTE: These two stable-baselines3 imports are inferred; the excerpt uses
# `OffPolicyAlgorithm` and `TrainFreq` without showing where they come from.
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import TrainFreq

from simple_parsing.helpers.serialization import register_decoding_fn

from sequoia.common.hparams import log_uniform, uniform
from sequoia.settings.rl import ContinualRLSetting
from sequoia.utils.logging_utils import get_logger

from .base import SB3BaseHParams, StableBaselines3Method

logger = get_logger(__file__)


def decode_trainfreq(v: Any) -> Any:
    """Decode a serialized `[frequency, unit]` pair back into a `TrainFreq`."""
    if isinstance(v, list) and len(v) == 2:
        return TrainFreq(v[0], v[1])
    return v


register_decoding_fn(TrainFreq, decode_trainfreq)
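
A quick sanity check of the decoder. `TrainFreq` is SB3's `(frequency, unit)` named tuple; passing the raw "step" string mirrors what the snippet stores, though SB3 itself uses a `TrainFrequencyUnit` enum for the second field:

assert decode_trainfreq([4, "step"]) == TrainFreq(4, "step")
assert decode_trainfreq(7) == 7  # anything else passes through unchanged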


class OffPolicyModel(OffPolicyAlgorithm, ABC):
    """ Tweaked version of the OffPolicyAlgorithm from SB3. """
    @dataclass
    class HParams(SB3BaseHParams):
        """ Hyper-parameters common to all off-policy algos from SB3. """

        # The learning rate. It can also be a function of the current
        # progress (which goes from 1 down to 0).
        learning_rate: Union[float, Callable] = log_uniform(1e-6,
                                                            1e-2,
                                                            default=1e-4)
        # Size of the replay buffer.
        buffer_size: int = uniform(100, 10_000_000, default=1_000_000)
Example #6
"""Registers custom encoding / decoding functions that are used by
simple-parsing when serializing objects to json or yaml.
"""
import enum
import inspect
from pathlib import Path
from typing import Any, List, Union, Type

import numpy as np
import torch
from torch import Tensor, nn, optim

from simple_parsing.helpers import encode
from simple_parsing.helpers.serialization import register_decoding_fn

# Register functions for decoding Tensor and ndarray fields from json/yaml.
register_decoding_fn(Tensor, torch.as_tensor)
register_decoding_fn(np.ndarray, np.asarray)
# Identity "decoders": `Type[...]` fields are passed through unchanged.
register_decoding_fn(Type[nn.Module], lambda v: v)
register_decoding_fn(Type[optim.Optimizer], lambda v: v)


# NOTE: With this registered, tensors are kept as-is when calling `to_dict` on
# a Serializable dataclass (comment it out to disable that behaviour).
@encode.register(Tensor)
def no_op_encode(value: Any) -> Any:
    return value
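
As a sketch of what these registrations buy, assuming a simple_parsing `Serializable` dataclass with a `Tensor` field (the `Config` class below is made up for illustration):

from dataclasses import dataclass, field

from simple_parsing.helpers import Serializable


@dataclass
class Config(Serializable):
    weights: Tensor = field(default_factory=lambda: torch.zeros(3))


c = Config.from_dict({"weights": [1.0, 2.0, 3.0]})  # decoded via torch.as_tensor
assert isinstance(c.weights, Tensor)
assert isinstance(c.to_dict()["weights"], Tensor)   # kept as-is by no_op_encode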


# TODO: Look deeper into how things are pickled and moved by pytorch-lightning.
# Right now pytorch-lightning emits a warning saying that some metrics will not
# be included in a checkpoint because they are lists instead of Tensors.
Example #7
import enum

from simple_parsing.helpers import encode
from simple_parsing.helpers.serialization import register_decoding_fn


class Schema(enum.Enum):
    # NOTE: The imports and class header above are inferred; the original
    # excerpt starts mid-enum. The KEYPOINT_2D member (used in `main` below)
    # is reconstructed here, with its value guessed from the naming pattern.
    KEYPOINT_2D = "keypoint_2d"
    QUATERNION = "quaternion"
    KEYPOINT_OFFSET = "keypoint_offset"
    BOX_2D = "box_2d"
    INTRINSIC_MATRIX = "camera/intrinsics"
    CENTER_2D = "center_2d"
    SCALE_MAP = "scale_map"  # dense per-element scale prediction.
    INDEX = "instance_index"  # (batch_index, instance_index)


@encode.register(Schema)
def encode_schema(obj: Schema) -> str:
    """Encode the enum with the underlying `str` representation."""
    return str(obj.value)


def decode_schema(obj: str) -> Schema:
    """Decode str into Schema enum."""
    return Schema(obj)


register_decoding_fn(Schema, decode_schema)


def main():
    s = encode_schema(Schema.KEYPOINT_2D)
    assert decode_schema(s) is Schema.KEYPOINT_2D  # round trip


if __name__ == '__main__':
    main()
Example #8
# NOTE(ycho): Register encoder-decoder pair for `DatasetOptions` enum.
# NOTE(ycho): Parsing from type annotations: only available for python>=3.7.


@encode.register(DatasetOptions)
def encode_dataset_options(obj: DatasetOptions) -> str:
    """Encode the enum with the underlying `str` representation."""
    return str(obj.value)


def decode_dataset_options(obj: str) -> DatasetOptions:
    """Decode the underlying `str` representation back into the enum."""
    return DatasetOptions(obj)


register_decoding_fn(DatasetOptions, decode_dataset_options)


@dataclass
class DatasetSettings(Serializable):
    dataset: DatasetOptions = DatasetOptions.OBJECTRON
    # NOTE: Default factories avoid sharing one Settings instance across all
    # DatasetSettings objects (requires `from dataclasses import field`).
    cube: ColoredCubeDataset.Settings = field(
        default_factory=ColoredCubeDataset.Settings)
    objectron: ObjectronDetection.Settings = field(
        default_factory=ObjectronDetection.Settings)
    cache: CachedDataset.Settings = field(
        default_factory=CachedDataset.Settings)
    use_cached_dataset: bool = False
    shuffle: bool = False
    num_workers: int = 0
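
With the encoder/decoder pair above registered, the enum field round-trips through its string value. A sketch, assuming the nested `Settings` classes are themselves `Serializable`:

opts = DatasetSettings(dataset=DatasetOptions.OBJECTRON)
restored = DatasetSettings.loads_yaml(opts.dumps_yaml())
assert restored.dataset is DatasetOptions.OBJECTRON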


def get_dataset(opts: DatasetSettings,
                train: bool,
Example #9
        num_classes = max(num_classes, y_t_prob.shape[-1])

        y_t_logits = np.zeros([y_t_prob.shape[0], num_classes],
                              dtype=y_t_prob.dtype)

        for i, logits in enumerate(y_t_prob):
            for label, logit in zip(classes, logits):
                y_t_logits[i][label - 1] = logit

    ## We were constructing this to reorder the classes, in case the ordering
    ## differed between the KNN's internal `classes_` attribute and the task
    ## classes. However, I'm not sure this is still necessary.

    # y_t_logits = np.zeros((y_t.size, y_t.max() + 1))
    # for i, label in enumerate(classes):
    #     y_t_logits[:, label] = y_t_prob[:, i]

    # We get the negative cross-entropy using the scikit-learn function, but we
    # could instead get it using pytorch's function (maybe even inside the
    # `Loss` object!).
    nce_t = log_loss(y_true=y_t, y_pred=y_t_prob, labels=classes)
    # BUG: There is sometimes a case where some classes aren't present in
    # `classes_`, and as such the ClassificationMetrics object created in the
    # Loss constructor has an error.
    test_loss = Loss(loss_name, loss=nce_t, y_pred=y_t_logits, y=y_t)
    return test_loss


from simple_parsing.helpers.serialization import register_decoding_fn

# Identity "decoder": `KnnCallback` fields are passed through as-is instead of
# being reconstructed from a serialized form.
register_decoding_fn(KnnCallback, lambda v: v)