Exemplo n.º 1
0
 def test_params(self, instance: task.Actor, hyperparams):
     """Check that params read back as given and can be updated."""
     initial = instance.get_params()
     assert initial == hyperparams
     assert 'x' not in initial
     instance.set_params(x=100)
     updated = instance.get_params()
     assert updated['x'] == 100
Exemplo n.º 2
0
    def _function(actor: task.Actor, *args) -> bytes:
        """Run the consumer objective: train the actor, then snapshot it.

        Args:
            actor: Actor whose ``train`` method is invoked.
            *args: Positional arguments forwarded verbatim to ``actor.train``.

        Returns:
            The value of ``actor.get_state()`` taken after training.
        """
        actor.train(*args)
        new_state = actor.get_state()
        return new_state
Exemplo n.º 3
0
        def state(actor: task.Actor, state: bytes) -> None:
            """Predefined shifting for a state-taking objective.

            Applies the given serialized state to the actor in place.

            Args:
                actor: Target actor whose ``set_state`` is called.
                state: Serialized actor state (bytes) to apply.

            Returns:
                None; the actor is modified in place.
            """
            LOGGER.debug('%s receiving state (%d bytes)', actor, len(state))
            actor.set_state(state)
Exemplo n.º 4
0
        def params(actor: task.Actor,
                   params: typing.Mapping[str, typing.Any]) -> None:
            """Predefined shifting for a params-taking objective.

            Applies the given parameters to the actor in place.

            Args:
                actor: Target actor whose ``set_params`` is called.
                params: Mapping of parameter names to values; each entry is
                    passed to ``actor.set_params`` as a keyword argument.

            Returns:
                None; the actor is modified in place.
            """
            LOGGER.debug('%s receiving params (%s)', actor, params)
            actor.set_params(**params)
Exemplo n.º 5
0
    def _function(actor: task.Actor, *args) -> typing.Any:
        """Run the mapper objective by delegating to ``actor.apply``.

        Args:
            actor: Actor whose ``apply`` method produces the output.
            *args: Positional arguments forwarded verbatim to ``actor.apply``.

        Returns:
            Whatever ``actor.apply`` returned, unchanged.
        """
        output = actor.apply(*args)
        # '%.1024s' caps the rendered result at 1024 characters in the log.
        LOGGER.debug('%s result: %.1024s...', actor, output)
        return output
Exemplo n.º 6
0
 def test_serializable(self, instance: task.Actor, trainset, testset,
                       prediction):
     """Verify a trained actor still predicts correctly after pickling."""
     instance.train(*trainset)
     restored = pickle.loads(pickle.dumps(instance))
     assert restored.predict(testset) == prediction
Exemplo n.º 7
0
 def test_state(self, instance: task.Actor, trainset, state, testset,
                prediction):
     """Check state capture/restore and that restoring keeps set params."""
     instance.train(*trainset)
     assert instance.predict(testset) == prediction
     assert instance.get_state() == state
     # Retrain on dummy data so the actor's state diverges from `state`.
     instance.train('foo', 'bar')
     assert instance.predict(testset) != prediction
     assert 'x' not in instance.get_params()
     instance.set_params(x=100)
     instance.set_state(state)
     # Restoring state must not clobber an explicitly set parameter.
     params_after_restore = instance.get_params()
     assert params_after_restore['x'] == 100
Exemplo n.º 8
0
 def test_train(self, instance: task.Actor, trainset, testset, prediction):
     """Check that a trained actor's apply output matches the prediction."""
     assert instance.is_stateful()
     instance.train(*trainset)
     outcome = instance.apply(testset)
     assert outcome == prediction