Beispiel #1
0
 def testExecutionWatcher_Local(self):
     """Sends one custom-property update and checks it was written to MLMD.

     The update is sent through `self.stub`. The execution is then read back
     via `self._mlmd_connection` and compared against the expected proto,
     ignoring `type_id` and the create/update timestamps.
     """
     new_value = metadata_store_pb2.Value()
     new_value.string_value = 'string_value'
     request = execution_watcher_pb2.UpdateExecutionInfoRequest()
     request.execution_id = self._execution.id
     request.updates['test_key'].CopyFrom(new_value)

     response = self.stub.UpdateExecutionInfo(request)
     self.assertEqual(execution_watcher_pb2.UpdateExecutionInfoResponse(),
                      response)

     # Read the execution back from the store to confirm persistence.
     with self._mlmd_connection as mlmd_handle:
         stored = mlmd_handle.store.get_executions_by_id([self._execution.id])
     self.assertEqual(len(stored), 1)

     # NOTE(review): `id: 1` assumes the tracked execution is the first one
     # registered in the store by the test fixture -- confirm against setUp.
     expected_proto_text = """
   id: 1
   last_known_state: RUNNING
   custom_properties {
     key: "test_key"
     value {
       string_value: "string_value"
     }
   }
   """
     fields_to_ignore = [
         'type_id',
         'create_time_since_epoch',
         'last_update_time_since_epoch',
     ]
     self.assertProtoPartiallyEquals(expected_proto_text,
                                     stored[0],
                                     ignored_fields=fields_to_ignore)
Beispiel #2
0
 def UpdateExecutionInfo(self, req, unused_context):
     """Handles an executor's request to update execution info.

     This handler only logs the requested updates. Nothing is written to MLMD
     yet, and it always replies with an empty response.

     Args:
       req: The incoming UpdateExecutionInfoRequest. Its `updates` and
         `execution_id` fields are logged.
       unused_context: The RPC context. It is ignored.

     Returns:
       An empty UpdateExecutionInfoResponse.
     """
     # TODO(ericlege): implement this rpc to log updates to MLMD.
     del unused_context  # Deliberately unused by this stub.
     requested_updates = req.updates
     target_execution_id = req.execution_id
     logging.info(
         'Received request to update execution info: updates %s, '
         'execution_id %s', requested_updates, target_execution_id)
     empty_response = execution_watcher_pb2.UpdateExecutionInfoResponse()
     return empty_response
Beispiel #3
0
 def UpdateExecutionInfo(
     self, request: execution_watcher_pb2.UpdateExecutionInfoRequest,
     context: grpc.ServicerContext
 ) -> execution_watcher_pb2.UpdateExecutionInfoResponse:
     """Merges `request.updates` into the tracked execution's custom properties.

     Each entry of `request.updates` is copied into
     `self._execution.custom_properties` under the same key. The execution is
     then written back with `put_executions` on `self._mlmd_connection`.

     If `request.execution_id` does not match the execution this server
     tracks, nothing is written. The RPC status is set to NOT_FOUND with an
     explanatory detail message.

     Args:
       request: The update request, carrying the target execution id and a
         map of property updates.
       context: The gRPC servicer context, used to report NOT_FOUND.

     Returns:
       An empty UpdateExecutionInfoResponse in every case.
     """
     logging.info(
         'Received request to update execution info: updates %s, '
         'execution_id %s', request.updates, request.execution_id)
     tracked_execution = self._execution
     if request.execution_id == tracked_execution.id:
         for property_name, property_value in request.updates.items():
             tracked_execution.custom_properties[property_name].CopyFrom(
                 property_value)
         # Persist only the (now updated) execution object.
         with self._mlmd_connection as mlmd_handle:
             mlmd_handle.store.put_executions((tracked_execution, ))
     else:
         context.set_code(grpc.StatusCode.NOT_FOUND)
         context.set_details(
             'Execution with given execution_id not tracked by server: '
             f'{request.execution_id}')
     return execution_watcher_pb2.UpdateExecutionInfoResponse()
 def testExecutionWatcher_LocalWithEmptyRequest(self):
     """An empty UpdateExecutionInfo request yields an empty response.

     The test starts a standalone ExecutionWatcher on an unused port with
     local credentials, sends a default (empty) request over a local secure
     channel, and checks that the reply equals a default response.
     """
     port = portpicker.pick_unused_port()
     sidecar = execution_watcher.ExecutionWatcher(
         port, creds=grpc.local_server_credentials())
     sidecar.start()
     # Always stop the server, even if channel setup or the RPC raises, so a
     # failing test does not leave a gRPC server (and its port) running.
     try:
         creds = grpc.local_channel_credentials()
         # The `with` block closes the client channel once the RPC is done.
         with grpc.secure_channel(sidecar.local_address, creds) as channel:
             stub = execution_watcher_pb2_grpc.ExecutionWatcherServiceStub(
                 channel)
             req = execution_watcher_pb2.UpdateExecutionInfoRequest()
             res = stub.UpdateExecutionInfo(req)
     finally:
         sidecar.stop()
     self.assertEqual(execution_watcher_pb2.UpdateExecutionInfoResponse(),
                      res)