def train_status(self, request: PluginRequest[TrainStatusPluginInput], model: ThirdPartyModel) -> Response[TrainPluginOutput]:
    """Check a third-party training job and save the trained model once it has finished.

    Args:
        request: Plugin request. Its ``data`` is forwarded to ``model.train_status``; its
            ``task_id`` binds the returned Response to a Task, and its ``plugin_instance_id``
            is passed to ``model.save_remote`` when training is complete.
        model: The third-party model whose training progress is being checked.

    Returns:
        A ``Response`` bound to the request's task whose data is the ``TrainPluginOutput``
        produced by ``model.train_status``. When that output reports ``training_complete``,
        its ``archive_path`` is set to the value returned by ``model.save_remote``.
    """
    # Label this log line with the method's real name, so status polls are not mistaken for train calls.
    logging.debug(f"train_status {request}")
    # Attaching a Task to the Response is what lets updates be streamed back for this request.
    # TODO: This is very non-intuitive. We should improve this.
    response = Response(status=Task(task_id=request.task_id))

    train_plugin_output = model.train_status(request.data)
    if train_plugin_output.training_complete:
        # Only a finished job has a model worth saving. Save it with the `default` handle
        # and record where the archive ended up.
        train_plugin_output.archive_path = model.save_remote(
            client=self.client, plugin_instance_id=request.plugin_instance_id
        )

    response.set_data(json=train_plugin_output)
    logging.info(response.dict(by_alias=True))
    return response
def train(self, request: PluginRequest[TrainPluginInput], model: ThirdPartyModel) -> Response[TrainPluginOutput]:
    """Kick off training of a third-party model and return its output in a task-bound Response.

    This method forwards ``request.data`` to ``model.train`` and returns whatever that call
    produces. It does not save a model archive; persisting the trained model is left to the
    trainer itself, since training cannot be assumed to finish synchronously.

    Args:
        request: Plugin request whose ``data`` is the training input and whose ``task_id``
            identifies the task the Response is bound to.
        model: The third-party model to train.

    Returns:
        A ``Response`` whose data is the ``TrainPluginOutput`` returned by ``model.train``.
    """
    logging.debug(f"train {request}")

    # Binding the Response to the caller's Task is what allows updates to be streamed back.
    # TODO: This is very non-intuitive. We should improve this.
    task = Task(task_id=request.task_id)
    response = Response(status=task)

    response.set_data(json=model.train(request.data))
    logging.info(response.dict(by_alias=True))
    return response
def train(self, request: PluginRequest[TrainPluginInput], model: TestTrainableTaggerModel) -> Response[TrainPluginOutput]:
    """Train the test tagger, save it via ``model.save_remote``, and return the archive location.

    Training here is treated as synchronous. The model is trained and saved, and the
    resulting ``archive_path`` is placed on the ``TrainPluginOutput`` before returning.

    Args:
        request: Plugin request carrying the training input (``data``), the ``task_id``
            the Response is bound to, and the ``plugin_instance_id`` used when saving.
        model: The trainable test tagger model.

    Returns:
        A ``Response`` whose data is the training output, with ``archive_path`` set.
    """
    logging.info(f"TestTrainableTaggerPlugin:train() {request}")

    # A Task attached to the Response is what lets status updates be streamed back.
    # TODO: This is very non-intuitive. We should improve this.
    response = Response(status=Task(task_id=request.task_id))
    # To report progress at this point, set `response.status.status_message` and call
    # `response.post_update(client=self.client)`.

    output = model.train(request.data)

    # Save the model with the `default` handle and attach its location to the output.
    saved_path = model.save_remote(client=self.client, plugin_instance_id=request.plugin_instance_id)
    logging.info(f"TestTrainableTaggerPlugin:train() setting model archive path to {saved_path}")
    output.archive_path = saved_path

    response.set_data(json=output)

    # Completion could instead be pushed to the Engine by setting `response.status.status_message`
    # and `response.status.state = TaskState.succeeded`, then calling
    # `response.post_update(client=self.client)`.
    #
    # NOTE: the result is returned directly because training is assumed to be synchronous.
    # Models running elsewhere (e.g. on ECS or a third-party system) may still be running
    # when the handler returns. A future change is expected to report a `running` status for
    # those, indicating that execution continues after the handler returns.
    logging.info(f"TestTrainableTaggerPlugin:train() returning {response}")
    return response
def test_response_post_update_can_update_task():
    """Check that ``Response.post_update`` persists task state, message and output to the server."""
    steamship = get_steamship_client()
    created = create_dummy_training_task(steamship)
    live_task = created.task

    target_state = TaskState.failed
    target_message = "HI THERE"
    target_output = {"a": 3}

    # The fresh task must not already hold the values this test is about to write.
    assert live_task.state != target_state
    assert live_task.status_message != target_message
    assert live_task.output != target_output

    response = Response(status=live_task)
    status = response.status
    status.state = target_state
    status.status_message = target_message
    status.output = target_output

    # Sanity check: refreshing without posting is expected to restore the server-side values,
    # which shows the local edits above were never sent.
    created.refresh()
    assert live_task.state != target_state
    assert live_task.status_message != target_message
    assert live_task.output != target_output

    # Re-apply the edits, this time routing the output through set_data, and push them.
    status.state = target_state
    status.status_message = target_message
    response.set_data(json=target_output)
    response.post_update(steamship)

    # After refreshing again, the server should report exactly what was posted.
    # The output is compared as JSON text.
    created.refresh()
    assert live_task.state == target_state
    assert live_task.status_message == target_message
    assert live_task.output == json.dumps(target_output)