예제 #1
0
class Preprocessing:
    """A wrapper around `tf.data` preprocessing.

    Subclasses must either override both `input` and `output`, or override
    `__call__` directly. The base implementations of `input` and `output`
    raise `NotImplementedError`.
    """

    # NOTE(review): presumably per-feature decoders forwarded to `tfds` when
    # the dataset is loaded; confirm against the caller.
    decoders: Optional[Dict[str, Union[tfds.decode.Decoder,
                                       Dict]]] = Field(None)

    # The shape of the processed input. Must match the output of `input()`.
    input_shape: Tuple[int, int, int] = Field()

    def input(self, data, training: bool = False) -> tf.Tensor:
        """
        A method to define preprocessing for model input. This method or
        `__call__` needs to be overridden by all subclasses.

        Arguments:
            data:
                A dictionary of type {feature_name: tf.Tensor}.
            training:
                An optional `bool` to indicate whether the data is training
                data. Defaults to `False`, consistent with `__call__`.
        Returns:
            A tensor of processed input, with shape `self.input_shape`.
        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """

        raise NotImplementedError("Must be implemented in subclasses.")

    def output(self, data, training: bool = False) -> tf.Tensor:
        """
        A method to define preprocessing for model output. This method or
        `__call__` needs to be overridden by all subclasses.

        Arguments:
            data:
                A dictionary of type {feature_name: tf.Tensor}.
            training:
                An optional `bool` to indicate whether the data is training
                data. Defaults to `False`, consistent with `__call__`.
        Returns:
            A tensor of processed output.
        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """

        raise NotImplementedError("Must be implemented in subclasses.")

    def __call__(self, data, training=False) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply preprocessing by running `input` and then `output` on `data`.

        Arguments:
            data:
                A dictionary of type {feature_name: tf.Tensor}.
            training:
                An optional `bool` to indicate whether the data is training
                data.
        Returns:
            A pair of processed input and output.
        """

        # NOTE(review): `pass_training_kwarg` presumably forwards `training`
        # only to callables that accept it, so overrides may omit the
        # argument; confirm against its definition.
        input_fn = pass_training_kwarg(self.input, training=training)
        output_fn = pass_training_kwarg(self.output, training=training)
        return input_fn(data), output_fn(data)
예제 #2
0
    class TestTask:
        """Example task with three fields; `run` prints their values."""

        a: int = Field()
        b: str = Field("foo")
        c: bool = Field(False)

        def run(self):
            # Print the three field values on a single line, in declaration order.
            values = (self.a, self.b, self.c)
            print(*values)
예제 #3
0
def test_allow_missing():
    """`allow_missing=True` is accepted only on a field without a default."""
    # No default value is supplied, so constructing the field must succeed.
    field_without_default = Field(allow_missing=True)
    assert field_without_default.allow_missing

    # Supplying a default together with `allow_missing=True` is rejected.
    with pytest.raises(ValueError):
        Field(3.14, allow_missing=True)
    class A:
        """Component whose `c` falls back to `a` when `b` is not available."""

        a: int = Field()
        b: float = Field(allow_missing=True)

        @Field
        def c(self) -> float:
            # `b` is declared with allow_missing=True, so probe for it first.
            return self.b if hasattr(self, "b") else self.a
예제 #5
0
def test_unregistered_field():
    """Default-related access on a field not attached to a component raises."""
    expected_message = "This field has not been registered to a component"
    field = Field(5)

    # The property is evaluated only for its side effect of raising.
    with pytest.raises(ValueError, match=expected_message):
        field.has_default

    with pytest.raises(ValueError, match=expected_message):
        field.get_default(object())
    class Factory:
        """Factory whose `build` returns the nested child's `x`."""

        child: Child = ComponentField(Child)

        x: int = Field(5)

        def build(self) -> int:
            # Read `x` from the nested component, not from `self.x`.
            nested = self.child
            return nested.x
예제 #7
0
class Experiment:
    """
    A wrapper around a Keras experiment. Subclasses must implement their
    training loop in `run`.

    This base class only declares configuration: three nested components
    and four parameters. It does not define `run` itself.
    """

    # Nested components
    # NOTE(review): `ComponentField()` is called with no class argument, so
    # presumably each component must be supplied by configuration; confirm.
    dataset: Dataset = ComponentField()
    preprocessing: Preprocessing = ComponentField()
    model: keras.models.Model = ComponentField()

    # Parameters (all declared with `Field()`, i.e. without a default value)
    epochs: int = Field()
    batch_size: int = Field()
    # A `keras.losses.Loss` instance or a string, per the annotation.
    # NOTE(review): the annotation allows `None`, yet no default is given;
    # confirm whether `Field(None)` was intended.
    loss: Optional[Union[keras.losses.Loss, str]] = Field()
    # A `keras.optimizers.Optimizer` instance or a string, per the annotation.
    optimizer: Union[keras.optimizers.Optimizer, str] = Field()
예제 #8
0
    class C:
        """Component whose pre-configure hook doubles `x` when it is in `conf`."""

        x: int = Field(3)

        def __pre_configure__(self, conf):
            # The hook expects to run while the component is not yet configured.
            assert not self.__component_configured__
            if "x" not in conf:
                return conf
            conf["x"] = conf["x"] * 2
            return conf
 class Child:
     """Component with one integer field `a`, declared without a default."""
     # NOTE(review): allow_missing=True presumably lets `a` stay unset; confirm.
     a: int = Field(allow_missing=True)
 class Child1(Base):
     """Subclass of `Base` that declares field `a` with a default of 5."""
     # No annotation here; presumably overrides the `a` field inherited from `Base`.
     a = Field(5)
 class B:
     """Component with one integer field `b`, declared without a default."""
     b: int = Field()
 class GrandParent:
     """Component with a float field `c` (default 3.14) and a nested `parent`."""
     c: float = Field(3.14)
     # Nested component; presumably defaults to `Parent` (confirm ComponentField semantics).
     parent: Parent = ComponentField(Parent)
 class Child:
     """Component with one integer field `a`, defaulting to 1."""
     a: int = Field(1)
예제 #14
0
 class Child:
     """Component with one float field `x_Y_z`, defaulting to 0.0."""
     x_Y_z: float = Field(0.0)
예제 #15
0
 class Parent:
     """Component with an integer field `x` (default 7) and a nested `child`."""
     x: int = Field(7)
     # Nested component; presumably defaults to `Child` (confirm ComponentField semantics).
     child: Child = ComponentField(Child)
예제 #16
0
 class Child:
     """Component declaring `x` without a default; its value comes from the parent."""
     x: int = Field()  # Inherited from parent
    class C:
        """Component that stores `a + b` as `c` in its post-configure hook."""

        a: int = Field(0)
        b: float = Field(3.14)

        def __post_configure__(self):
            # `c` is a plain instance attribute, not a declared Field.
            total = self.a + self.b
            self.c = total
예제 #18
0
    class ParentTask:
        """Task with a nested `child`; `run` prints its own `a` and the child's `x_Y_z`."""

        a: int = Field(2)
        child: Child = ComponentField(Child)

        def run(self):
            # Read the parent's field before the child's, as in a single print call.
            own_value = self.a
            child_value = self.child.x_Y_z
            print(own_value, child_value)
 class Parent:
     """Component with a string field `b` (default "foo") and a nested `child`."""
     b: str = Field("foo")
     # Nested component; presumably defaults to `Child` (confirm ComponentField semantics).
     child: Child = ComponentField(Child)
 class Child:
     """Component with fields `a`, `b` and `c`, all declared without defaults."""
     a: int = Field()
     b: str = Field()
     c: List[float] = Field()
 class A:
     """Component with one integer field `a`, declared without a default."""
     a: int = Field()
 class GrandParent:
     """Component with fields `a` and `b` (no defaults) and a nested `parent`."""
     a: int = Field()
     b: str = Field()
     # Nested component; presumably defaults to `Parent` (confirm ComponentField semantics).
     parent: Parent = ComponentField(Parent)
 class C(B):
     """Subclass of `B` adding an integer field `c`, declared without a default."""
     c: int = Field()
 class A:
     """Component with an integer field `a` (no default) and `b` (default "foo")."""
     a: int = Field()
     b: str = Field("foo")
 class Base:
     """Base component with one integer field `a`, declared without a default."""
     a: int = Field()
 class Child:
     """Component with fields `a`, `b`, `c` (no defaults) and `d` (allow_missing=True)."""
     a: int = Field()
     b: str = Field()
     c: List[float] = Field()
     d: int = Field(allow_missing=True)
 class Child2(Base):
     """Subclass of `Base` that declares field `a` with a default of 5."""
     # No annotation here; presumably overrides the `a` field inherited from `Base`.
     a = Field(5)
 class SuperClass:
     """Component with one string field `foo`, defaulting to "bar"."""
     foo: str = Field("bar")
 class Parent:
     """Component with an integer field `a` (default 5) and a nested `child`."""
     a: int = Field(5)
     # Nested component; presumably defaults to `Child` (confirm ComponentField semantics).
     child: Child = ComponentField(Child)
 class Child:
     """Component with one integer field `x`, declared without a default."""
     x: int = Field()