Code Example #1
    def train_status(self, request: PluginRequest[TrainStatusPluginInput],
                     model: ThirdPartyModel) -> Response[TrainPluginOutput]:
        """Since trainable can't be assumed to be asynchronous, the trainer is responsible for uploading its own model file."""
        logging.debug(f"train {request}")

        # Create a Response object at the top with a Task attached. This will let us stream back updates
        # TODO: This is very non-intuitive. We should improve this.
        response = Response(status=Task(task_id=request.task_id))

        # Call train status
        train_plugin_output = model.train_status(request.data)

        if train_plugin_output.training_complete:
            # Save the model with the `default` handle.
            archive_path_in_steamship = model.save_remote(
                client=self.client,
                plugin_instance_id=request.plugin_instance_id)

            # Set the model location on the plugin output.
            train_plugin_output.archive_path = archive_path_in_steamship

        # Set the response on the `data` field of the object
        response.set_data(json=train_plugin_output)
        logging.info(response.dict(by_alias=True))
        return response
Code Example #2
    def train(self, request: PluginRequest[TrainPluginInput],
              model: ThirdPartyModel) -> Response[TrainPluginOutput]:
        """Since trainable can't be assumed to be asynchronous, the trainer is responsible for uploading its own model file."""
        logging.debug(f"train {request}")

        # Create a Response object at the top with a Task attached. This will let us stream back updates
        # TODO: This is very non-intuitive. We should improve this.
        response = Response(status=Task(task_id=request.task_id))

        # Train the model
        train_plugin_input = request.data
        train_plugin_output = model.train(train_plugin_input)

        # Set the response on the `data` field of the object
        response.set_data(json=train_plugin_output)
        logging.info(response.dict(by_alias=True))
        return response
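
Both handlers above delegate to a ThirdPartyModel instance. The base class itself is not shown in these excerpts; the outline below is a hypothetical sketch of the interface implied by the calls in Code Examples #1 and #2 (train, train_status, save_remote, and the fields read off their outputs), not the actual Steamship class.

class ThirdPartyModel:
    """Hypothetical outline inferred from the handlers above -- not the real base class."""

    def train(self, train_input: TrainPluginInput) -> TrainPluginOutput:
        """Run (or kick off) training and return a TrainPluginOutput describing it."""
        raise NotImplementedError()

    def train_status(self, status_input: TrainStatusPluginInput) -> TrainPluginOutput:
        """Report training progress; the handler in Code Example #1 checks `training_complete` on the result."""
        raise NotImplementedError()

    def save_remote(self, client, plugin_instance_id: str) -> str:
        """Upload the trained model archive and return its path in Steamship."""
        raise NotImplementedError()
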
Code Example #3
    def train(self, request: PluginRequest[TrainPluginInput],
              model: TestTrainableTaggerModel) -> Response[TrainPluginOutput]:
        """Since trainable can't be assumed to be asynchronous, the trainer is responsible for uploading its own model file."""
        logging.info(f"TestTrainableTaggerPlugin:train() {request}")

        # Create a Response object at the top with a Task attached. This will let us stream back updates
        # TODO: This is very non-intuitive. We should improve this.
        response = Response(status=Task(task_id=request.task_id))

        # Example of recording training progress
        # response.status.status_message = "About to train!"
        # response.post_update(client=self.client)

        # Train the model
        train_plugin_input = request.data
        train_plugin_output = model.train(train_plugin_input)

        # Save the model with the `default` handle.
        archive_path_in_steamship = model.save_remote(
            client=self.client, plugin_instance_id=request.plugin_instance_id)

        # Set the model location on the plugin output.
        logging.info(
            f"TestTrainableTaggerPlugin:train() setting model archive path to {archive_path_in_steamship}"
        )
        train_plugin_output.archive_path = archive_path_in_steamship

        # Set the response on the `data` field of the object
        response.set_data(json=train_plugin_output)

        # If we want we can post this to the Engine
        # response.status.status_message = "Done!"
        # response.status.state = TaskState.succeeded
        # response.post_update(client=self.client)

        # Or, if this training really did happen synchronously, we return it.
        # Some models (e.g. those running on ECS, or on a third party system) will not have completed by the time
        # the Lambda function finishes. For now, let's just pretend they're synchronous. But in a future PR when we
        # have a better method of handling such situations, the response below would include a `status` of type `running`
        # to indicate that, while the plugin handler has returned, the plugin's execution continues.
        logging.info(f"TestTrainableTaggerPlugin:train() returning {response}")
        return response
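
The closing comment in Code Example #3 describes a future, truly asynchronous mode in which the handler returns while training continues elsewhere. The sketch below illustrates that comment only; it is not existing plugin behavior, and it assumes the TaskState enum defines a running value (only succeeded and failed appear in these excerpts).

    def train(self, request: PluginRequest[TrainPluginInput],
              model: ThirdPartyModel) -> Response[TrainPluginOutput]:
        """Hypothetical asynchronous variant: start remote training and report that it is still running."""
        response = Response(status=Task(task_id=request.task_id))

        # Kick off training on the external system (e.g. ECS); it may outlive this handler.
        model.train(request.data)

        # Instead of attaching final output, mark the task as still in progress. Completion is
        # later detected and the model uploaded by the train_status handler in Code Example #1.
        response.status.state = TaskState.running  # assumed enum value
        response.status.status_message = "Training continues on the remote system"
        return response
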
Code Example #4
def test_response_post_update_can_update_task():
    client = get_steamship_client()
    task_result = create_dummy_training_task(client)
    task = task_result.task

    new_state = TaskState.failed
    new_message = "HI THERE"
    new_output = {"a": 3}

    assert task.state != new_state
    assert task.status_message != new_message
    assert task.output != new_output

    response = Response(status=task)

    response.status.state = new_state
    response.status.status_message = new_message
    response.status.output = new_output

    # Sanity check: we'll prove that calling task_result.refresh() resets this.
    task_result.refresh()

    # Assert not equal
    assert task.state != new_state
    assert task.status_message != new_message
    assert task.output != new_output

    # And override again
    response.status.state = new_state
    response.status.status_message = new_message
    response.set_data(json=new_output)

    # Now we call post_update
    response.post_update(client)

    # Refresh the task from the Engine
    task_result.refresh()

    # Assert equal
    assert task.state == new_state
    assert task.status_message == new_message
    assert task.output == json.dumps(new_output)
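
Note that the final assertion compares task.output to json.dumps(new_output) rather than to the dict itself: judging from this test, set_data(json=...) serializes the payload, so after post_update the Engine-side task holds the output as a JSON string.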