def test_params(self, instance: task.Actor, hyperparams):
    """Params setter/getter tests.

    Checks that the actor initially reports exactly the ``hyperparams`` fixture
    (without any ``x`` key), and that a subsequent ``set_params(x=100)`` is
    visible through ``get_params``.
    """
    initial = instance.get_params()
    assert initial == hyperparams
    assert 'x' not in initial

    instance.set_params(x=100)
    updated = instance.get_params()
    assert updated['x'] == 100
def _function(actor: task.Actor, *args) -> bytes:
    """Consumer objective is the train method.

    Trains the actor with the given arguments and hands back whatever
    ``actor.get_state()`` reports afterwards.

    Args:
        actor: Target actor to run the objective on.
        *args: Positional arguments forwarded unchanged to ``actor.train``.

    Returns:
        New actor state (the result of ``actor.get_state()`` after training).
    """
    actor.train(*args)
    new_state = actor.get_state()
    return new_state
def state(actor: task.Actor, state: bytes) -> None:
    """Predefined shifting for state taking objective.

    Logs the incoming state size at debug level and applies the state to the
    actor via ``actor.set_state``. Nothing is returned; the actor is modified
    in place by its own ``set_state`` implementation.

    Args:
        actor: Target actor to run the objective on.
        state: Actor state to be used.
    """
    # Log only the payload size, not its content, to keep debug output compact.
    LOGGER.debug('%s receiving state (%d bytes)', actor, len(state))
    actor.set_state(state)
def params(actor: task.Actor, params: typing.Mapping[str, typing.Any]) -> None:
    """Predefined shifting for params taking objective.

    Logs the incoming params at debug level and applies them to the actor as
    keyword arguments of ``actor.set_params``. Nothing is returned.

    Args:
        actor: Target actor to run the objective on.
        params: Actor params to be used; each key becomes a keyword argument
            of ``actor.set_params``.
    """
    LOGGER.debug('%s receiving params (%s)', actor, params)
    actor.set_params(**params)
def _function(actor: task.Actor, *args) -> typing.Any:
    """Mapper objective is the apply method.

    Args:
        actor: Target actor to run the objective on.
        *args: Positional arguments forwarded unchanged to ``actor.apply``.

    Returns:
        Whatever ``actor.apply`` produced, passed through untouched.
    """
    output = actor.apply(*args)
    # ``%.1024s`` truncates the str() form of the output to 1024 chars in the log.
    LOGGER.debug('%s result: %.1024s...', actor, output)
    return output
def test_serializable(self, instance: task.Actor, trainset, testset, prediction):
    """Test actor serializability.

    A trained actor is pickled and unpickled; the restored copy must still
    produce the expected prediction.
    """
    instance.train(*trainset)
    payload = pickle.dumps(instance)
    restored = pickle.loads(payload)
    # NOTE(review): task.Actor objectives elsewhere are driven through ``apply``;
    # confirm ``predict`` is really available on this ``instance`` fixture.
    assert restored.predict(testset) == prediction
def test_state(self, instance: task.Actor, trainset, state, testset, prediction):
    """Testing actor statefulness.

    Checks that training yields the expected state and prediction, that
    retraining changes the outcome, and that restoring a saved state does not
    discard params set after it was captured.
    """
    instance.train(*trainset)
    assert instance.predict(testset) == prediction
    assert instance.get_state() == state

    # Retrain on dummy data so the actor's state diverges from the fixture.
    instance.train('foo', 'bar')
    assert instance.predict(testset) != prediction

    assert 'x' not in instance.get_params()
    instance.set_params(x=100)
    instance.set_state(state)
    # Restoring the state must not override the explicitly set parameter.
    params_after = instance.get_params()
    assert params_after['x'] == 100
def test_train(self, instance: task.Actor, trainset, testset, prediction):
    """Test actor training.

    The actor must report itself as stateful, and once trained on ``trainset``
    its ``apply`` output on ``testset`` must equal ``prediction``.
    """
    assert instance.is_stateful()
    instance.train(*trainset)
    output = instance.apply(testset)
    assert output == prediction