예제 #1
0
 def validate(req):
     """Check that *req* holds exactly one doc whose id and 'id' tag are 2."""
     docs = req.docs
     assert len(docs) == 1
     only_doc = docs[0]
     assert only_doc.id == 2
     tag_dict = json_format.MessageToDict(only_doc.tags)
     assert tag_dict['id'] == 2
예제 #2
0
  def _Project(self, obj, projection, flag, leaf=False):
    """Evaluate() helper function.

    This function takes a resource obj and a preprocessed projection. obj
    is a dense subtree of the resource schema (some key values may be missing)
    and projection is a sparse, possibly improper, subtree of the resource
    schema. Improper in that it may contain paths that do not exist in the
    resource schema or object. _Project() traverses both trees simultaneously,
    guided by the projection tree. When a projection tree path reaches a
    non-existent obj tree path the projection tree traversal is pruned. When a
    projection tree path terminates with an existing obj tree path, that obj
    tree value is projected and the obj tree traversal is pruned.

    Since resources can be sparse a projection can reference values not present
    in a particular resource. Because of this the code is lenient on out of
    bound conditions that would normally be errors.

    Objects that are not handled as scalars are marked in
    self._been_here_done_that while they are being projected. Meeting a marked
    object again (a reference cycle) returns None instead of recursing. The
    mark is always removed afterwards, even if projecting the object raises.

    Args:
      obj: An object.
      projection: Projection _Tree node.
      flag: A bitmask of DEFAULT, INNER, PROJECT.
      leaf: Do not call _ProjectAttribute() if True.

    Returns:
      An object containing only the key:values selected by projection, or obj if
      the projection is None or empty.
    """
    objid = id(obj)
    if objid in self._been_here_done_that:
      # obj is already being projected further up the call stack.
      return None
    elif obj is None:
      pass
    elif isinstance(obj, six.text_type) or isinstance(obj, six.binary_type):
      # Don't use six.string_types because bytes are not considered a string
      # on Python 3.
      if isinstance(obj, six.binary_type):
        # If it's bytes, first decode it, then continue.
        obj = encoding.Decode(obj)
      # Check for {" because valid compact JSON keys are always "..." quoted.
      if (self._json_decode and (
          obj.startswith('{"') and obj.endswith('}') or
          obj.startswith('[') and obj.endswith(']'))):
        try:
          return self._Project(json.loads(obj), projection, flag, leaf=leaf)
        except ValueError:
          # OK if it's not JSON.
          pass
    elif (isinstance(obj, (bool, float, complex)) or
          isinstance(obj, six.integer_types)):
      # primitive data type
      pass
    elif isinstance(obj, bytearray):
      # bytearray copied to disassociate from original obj.
      obj = encoding.Decode(bytes(obj))
    elif isinstance(obj, protorpc_message.Enum):
      # protorpc enum
      obj = obj.name
    else:
      self._been_here_done_that.add(objid)
      try:
        if isinstance(obj, protorpc_message.Message):
          # protorpc message
          obj = protorpc_encoding.MessageToDict(obj)
        elif isinstance(obj, protobuf_message.Message):
          # protobuf message
          obj = protobuf_encoding.MessageToDict(obj)
        elif not hasattr(obj, '__iter__') or hasattr(obj, '_fields'):
          # class object or collections.namedtuple() (via the _fields test).
          obj = self._ProjectClass(obj, projection, flag)
        if (projection and projection.attribute and
            projection.attribute.transform and
            self._TransformIsEnabled(projection.attribute.transform)):
          # Transformed nodes prune here.
          obj = projection.attribute.transform.Evaluate(obj)
        elif ((flag >= self._projection.PROJECT or projection and projection.tree)
              and hasattr(obj, '__iter__')):
          if hasattr(obj, 'items'):
            try:
              obj = self._ProjectDict(obj, projection, flag)
            except (IOError, TypeError):
              obj = None
          else:
            try:
              obj = self._ProjectList(obj, projection, flag)
            except (IOError, TypeError):
              obj = None
      finally:
        # Unmark obj even if a conversion or transform raised. A stale id
        # would otherwise make a later object that reuses the same id() be
        # projected as None.
        self._been_here_done_that.discard(objid)
      return obj
    # _ProjectAttribute() may apply transforms functions on obj, even if it is
    # None. For example, a transform that returns 'FAILED' for None values.
    return obj if leaf else self._ProjectAttribute(obj, projection, flag)
예제 #3
0
 def testMessageToDict(self):
     """MessageToDict should camelCase field names and keep int values."""
     msg = json_format_proto3_pb2.TestMessage(int32_value=12345)
     self.assertEqual({'int32Value': 12345}, json_format.MessageToDict(msg))
예제 #4
0
    def Talk(self, input_text, session_id, model, product_code="", indri_score=0):
        """Send one user utterance to the SDS dialogue server and return its reply.

        Opens the dialogue session ``session_id`` on ``model``, runs
        understanding on ``input_text``, generates the system response, and
        closes the session once the generated response reports it is finished.

        Args:
            input_text: The user's utterance.
            session_id: Session key identifying the dialogue.
            model: Name of the SDS model to talk to.
            product_code: Currently unused; kept for interface compatibility.
            indri_score: Currently unused; kept for interface compatibility.

        Returns:
            A dict with the response text, system intent and dialogue-act
            type, current task, confidence, best SLU result, filled slots,
            the 'skill' filled entity (or None) and the filled slot items.
        """
        self.GetSdsServer(model)

        empty = empty_pb2.Empty()
        # The results are not used below; the calls are kept so the sequence
        # of server requests is unchanged.
        self.sds_stub.GetCurrentModels(empty)
        self.sds_stub.GetAvailableModels(empty)

        dp = sds_pb2.DialogueParam()
        dp.model = model
        # Slot values may be preset here, before Open(), e.g.
        # dp.slots["product_code"] = product_code
        dp.session_key = session_id
        dp.user_initiative = True

        self.sds_stub.Open(dp)

        sq = sds_pb2.SdsQuery()
        sq.model = dp.model
        sq.session_key = dp.session_key
        sq.utter = input_text

        intent = self.sds_stub.Understand(sq)
        entities = sds_pb2.Entities()
        entities.session_key = dp.session_key
        entities.model = dp.model
        # Same output as the former Python 2-only statement
        # `print '<intent>', intent`, but also valid on Python 3.
        print("<intent> {}".format(intent))
        print("[ENTITIES]: {}".format(entities))
        sds_utter = self.sds_stub.Generate(entities)  # recommended
        confidence = sds_utter.confidence

        skill = json_format.MessageToDict(intent).get('filledEntities')
        if skill:
            skill = skill.get('skill')
        if sds_utter.finished:
            self.sds_stub.Close(dp)

        res = {
            "response": sds_utter.response.replace('\n', ''),
            "intent": sds_utter.system_intent,
            "intent_only": sds_utter.system_da_type,
            "current_task": sds_utter.current_task,
            "confidence": confidence,
            "best_slu": intent.origin_best_slu,
            "slots": intent.filled_entities,
            "skill": skill,
            "intent.filled_slots.items": intent.filled_entities.items()
        }
        print("--------------------------------------------------")
        print("[TASK]: {}".format(res['current_task']))
        print("[INTENT]: {}".format(res['intent_only']))
        print("[SLOT]")
        for slot_name, slot_value in res["intent.filled_slots.items"]:
            print("{}: {}".format(slot_name, slot_value))
        print("[EMPTY_SLOT]: {}".format(intent.empty_entities))
        print("[CONFIDENCE]: {}".format(res['confidence']))
        print("[RESPONSE]: {}".format(res['response']))
        print("[res]: {}".format(res))
        print("--------------------------------------------------")
        return res
예제 #5
0
 def testScaffoldProgramToSpecs(self):
   """Checks custom_job() output when an additional_job_spec is supplied.

   The expected spec shows the first worker pool gaining the program's own
   container (image, command and placeholder-resolved args) while keeping the
   user's machine spec. The second pool keeps its user-supplied container
   override and accelerator settings.
   """
   expected_worker_pools = [
       {
           'replicaCount': 1,
           'machineSpec': {'machineType': 'n1-standard-4'},
           'containerSpec': {
               'imageUri': 'my_image:latest',
               'command': ['python', 'entrypoint.py'],
               'args': [
                   '--input_path',
                   "{{$.inputs.artifacts['examples'].uri}}",
                   '--output_path',
                   "{{$.outputs.artifacts['model'].uri}}",
                   '--optimizer',
                   "{{$.inputs.parameters['optimizer']}}",
                   '--output_param_path',
                   "{{$.outputs.parameters['out_param'].output_file}}",
               ],
           },
       },
       {
           'replicaCount': 4,
           'containerSpec': {
               'imageUri': 'gcr.io/my-project/my-worker-image:latest',
               'command': ['python3', 'override_entrypoint.py'],
               'args': ['--arg1', 'param1'],
           },
           'machineSpec': {
               'machineType': 'n1-standard-8',
               'acceleratorType': 'NVIDIA_TESLA_K80',
               'acceleratorCount': 1,
           },
       },
   ]
   expected_custom_job_spec = {
       'name': 'my-custom-job',
       'jobSpec': {'workerPoolSpecs': expected_worker_pools},
   }

   examples_input = dsl.PipelineParam(
       name='output', op_name='ingestor', param_type='Dataset')
   program_args = [
       '--input_path', structures.InputUriPlaceholder('examples'),
       '--output_path', structures.OutputUriPlaceholder('model'),
       '--optimizer', structures.InputValuePlaceholder('optimizer'),
       '--output_param_path', structures.OutputPathPlaceholder('out_param'),
   ]
   # Built independently of the expected spec so that any in-place changes
   # custom_job() makes cannot leak into the expectation.
   user_job_spec = {
       'workerPoolSpecs': [
           {
               'replicaCount': 1,
               'machineSpec': {'machineType': 'n1-standard-4'},
           },
           {
               'replicaCount': 4,
               'containerSpec': {
                   'imageUri': 'gcr.io/my-project/my-worker-image:latest',
                   'command': ['python3', 'override_entrypoint.py'],
                   'args': ['--arg1', 'param1'],
               },
               # Optionally one can also attach accelerators.
               'machineSpec': {
                   'machineType': 'n1-standard-8',
                   'acceleratorType': 'NVIDIA_TESLA_K80',
                   'acceleratorCount': 1,
               },
           },
       ],
   }

   task = aiplatform.custom_job(
       name='my-custom-job',
       input_artifacts={'examples': examples_input},
       input_parameters={'optimizer': 'sgd'},
       output_artifacts={'model': ontology_artifacts.Model},
       output_parameters={'out_param': str},
       image_uri='my_image:latest',
       commands=['python', 'entrypoint.py'],
       args=program_args,
       additional_job_spec=user_job_spec,
   )

   self.assertDictEqual(expected_custom_job_spec, task.custom_job_spec)
   self.assertDictEqual(_EXPECTED_COMPONENT_SPEC,
                        json_format.MessageToDict(task.component_spec))
   self.assertDictEqual(_EXPECTED_TASK_SPEC,
                        json_format.MessageToDict(task.task_spec))
   self.assertDictEqual(_EXPECTED_COMPONENT_SPEC ,
                        json_format.MessageToDict(task.component_spec))
   self.assertDictEqual(_EXPECTED_TASK_SPEC,
                        json_format.MessageToDict(task.task_spec))
예제 #6
0
 def _json_dict_(self) -> Dict[str, Any]:
     """Return the JSON-protocol dict for this Calibration.

     The metrics are serialized via the calibration's proto form, converted
     to a plain dict with ``json_format.MessageToDict``.
     """
     metrics_dict = json_format.MessageToDict(self.to_proto())
     return {'cirq_type': 'Calibration', 'metrics': metrics_dict}
예제 #7
0
  def _create_pipeline_v2(
      self,
      pipeline_func: Callable[..., Any],
      pipeline_name: Optional[str] = None,
      pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
  ) -> pipeline_spec_pb2.PipelineJob:
    """Compiles a @dsl.pipeline function into a PipelineJob proto.

    The pipeline function is called once with placeholder PipelineParams
    (without default values) to record its structure. The pipeline spec is
    then built from params that carry the declared defaults, and any
    overrides replace those values in the job's runtime config.

    Args:
      pipeline_func: Pipeline function with @dsl.pipeline decorator.
      pipeline_name: The name of the pipeline. Optional; falls back to the
        name extracted from the function's component interface.
      pipeline_parameters_override: The mapping from parameter names to values.
        Optional.

    Returns:
      A PipelineJob proto representing the compiled pipeline.

    Raises:
      ValueError: If an override names a parameter the pipeline does not have.
    """
    pipeline_meta = _python_op._extract_component_interface(pipeline_func)
    pipeline_name = pipeline_name or pipeline_meta.name
    pipeline_root = getattr(pipeline_func, 'pipeline_root', None)

    declared_inputs = pipeline_meta.inputs or []

    # Declared input name -> type. The first declaration of a name wins.
    input_types = {}
    for input_spec in declared_inputs:
      input_types.setdefault(input_spec.name, input_spec.type)

    # Placeholder arguments (no defaults) used to trace the pipeline body.
    placeholder_args = [
        dsl.PipelineParam(
            sanitize_k8s_name(arg_name, True),
            param_type=input_types.get(arg_name))
        for arg_name in inspect.signature(pipeline_func).parameters
    ]

    with dsl.Pipeline(pipeline_name) as dsl_pipeline:
      pipeline_func(*placeholder_args)

    self._sanitize_and_inject_artifact(dsl_pipeline)

    # Params carrying the declared default values.
    params_with_defaults = [
        dsl.PipelineParam(
            sanitize_k8s_name(input_spec.name, True),
            param_type=input_spec.type,
            value=input_spec.default)
        for input_spec in declared_inputs
    ]

    # A random root group name prevents clashes with template names.
    dsl_pipeline.groups[0].name = uuid.uuid4().hex

    pipeline_spec = self._create_pipeline_spec(
        params_with_defaults,
        dsl_pipeline,
    )

    params_by_name = {param.name: param for param in params_with_defaults}
    for param_name, param_value in (pipeline_parameters_override or {}).items():
      if param_name not in params_by_name:
        raise ValueError('Pipeline parameter {} does not match any known '
                         'pipeline argument.'.format(param_name))
      params_by_name[param_name].value = param_value

    runtime_config = compiler_utils.build_runtime_config_spec(
        output_directory=pipeline_root, pipeline_parameters=params_by_name)
    pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
    pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
    return pipeline_job
예제 #8
0
    def _get_parameters(self, model: Any) -> 'OrderedDict[str, Optional[str]]':
        """Flatten a model proto into an OpenML-style parameter mapping.

        Each list-valued entry of the model's ``graph`` becomes one parameter
        per element, keyed ``<field>_<index>_<element name>`` and holding that
        element as key-sorted JSON. Initializer elements have their
        ``<dataType>Data`` field removed first. Non-list graph entries are
        copied unchanged. All other top-level fields are stored together, as
        key-sorted JSON, under ``'backend'``.

        Args:
            model: The model protobuf message (presumably an ONNX ModelProto;
                it must have a ``graph`` field).

        Returns:
            The parameters as an OrderedDict sorted by key.

        Raises:
            ValueError: If an initializer's ``dataType`` is neither an int nor
                a str.
        """
        def _to_ordered(o):
            # Recursively key-sort the dicts in o. Dict values are replaced in
            # place. Inside lists only dict items are processed; nested lists
            # are kept as they are. Any other value is returned unchanged.
            if isinstance(o, dict):
                for key, val in o.items():
                    if isinstance(val, (dict, list)):
                        o[key] = _to_ordered(val)
                return OrderedDict(sorted(o.items(), key=lambda x: x[0]))
            if isinstance(o, list):
                return [_to_ordered(item) if isinstance(item, dict) else item
                        for item in o]
            return o

        # Convert the protobuf to python dictionary
        model_dic = json_format.MessageToDict(model)  # type: Dict[str, Any]

        # Initialize parameters dictionary
        parameters = {'backend': {}}  # type: Dict[str, Any]

        # Add graph information: one JSON-encoded parameter per list element;
        # non-list graph entries are copied unchanged.
        for key, value in sorted(model_dic['graph'].items(),
                                 key=lambda t: t[0]):
            if not isinstance(value, list):
                parameters[key] = value
                continue
            for index, val in enumerate(value):
                k = '{}_{}_{}'.format(key, str(index), val['name'])
                v = val
                if key == 'initializer':
                    # Determine the name of the attribute containing the data.
                    # dataType may be an enum number (mapped through
                    # ONNX_ATTR_TYPES) or already the enum name.
                    data_type = v['dataType']
                    if isinstance(data_type, int):
                        data_key = ONNX_ATTR_TYPES[data_type].lower() + 'Data'
                    elif isinstance(data_type, str):
                        data_key = data_type.lower() + 'Data'
                    else:
                        raise ValueError(
                            'Unknown data type. Try downgrading ONNX to 1.2.1.'
                        )
                    # Remove data from initializer as the model
                    # will be reinitialized after deserialization
                    del v[data_key]
                if isinstance(v, dict):
                    v = _to_ordered(v)
                parameters[k] = json.dumps(v)

        # Add backend information: every top-level field except the graph.
        parameters['backend'] = {
            key: value for key, value in model_dic.items() if key != 'graph'
        }
        parameters['backend'] = json.dumps(_to_ordered(parameters['backend']))

        # Sort the parameters dictionary as expected by OpenML
        return OrderedDict(sorted(parameters.items(), key=lambda x: x[0]))
예제 #9
0
def scheduler():
    """Handle a Mesos scheduler HTTP API call on ``/api/v1/scheduler``.

    The request body is either JSON or a serialized ``scheduler_pb2.Call``.
    The protobuf form is converted to a dict (proto field names preserved) so
    both are handled the same way.

    SUBSCRIBE returns a streaming response of RecordIO frames: the SUBSCRIBED
    event followed by offer events, tagged with a ``Mesos-Stream-Id`` header.
    The other supported call types are forwarded to the Mesos handler and
    answered with 202.

    Returns:
        A ``Response``:
        - 200 (streaming) for SUBSCRIBE.
        - 202 for accepted calls.
        - 400 for invalid or unsupported calls.
        - 405 for non-POST requests.
        - 500 if handling raises.
    """
    mesos_handler = MesosHttp.get_mesos_handler()
    request_debug("/api/v1/scheduler", request)

    def _recordio(payload):
        # RecordIO frame: "<length>\n<payload>". The length must count bytes,
        # and a serialized protobuf is bytes on Python 3, so text payloads
        # and the prefix are encoded before concatenation.
        if isinstance(payload, six.text_type):
            payload = payload.encode('utf-8')
        return str(len(payload)).encode('ascii') + b"\n" + payload

    try:
        if request.method != 'POST':
            logger.info("scheduler: unxpected request method: %s" %
                        request.method)
            return Response("Method %s not allowed" % request.method,
                            status=405)

        logger.info("POST")
        if request.is_json:
            logger.debug("json data")
            content = request.get_json()
        else:
            logger.debug("protobuf data")
            call = scheduler_pb2.Call()
            call.ParseFromString(request.data)
            content = json_format.MessageToDict(
                call, preserving_proto_field_name=True)
        logger.info("content=%s" % content)
        call_type = content['type']

        if call_type == 'SUBSCRIBE':
            if "Mesos-Stream-Id" in request.headers:
                msg = "Subscribe calls should not include the 'Mesos-Stream-Id' header"
                return Response(_recordio(msg), status=400)

            def generate():
                # The first yield is the framework id, consumed below to set
                # the stream header. Every later yield is a RecordIO frame.
                try:
                    mesos_subscribed = mesos_handler.http_subscribe(
                        content['subscribe']['framework_info'])
                    framework_name = mesos_subscribed['payload'][
                        'framework_id']['value']
                    logger.debug("Subscribed framework_name=%s, yield it" %
                                 framework_name)
                    yield framework_name
                    master_info = mesos_subscribed['payload']['master_info']

                    subscribed_json = json.dumps({
                        'type': 'SUBSCRIBED',
                        'subscribed': {
                            'framework_id': {
                                'value': framework_name
                            },
                            'heartbeat_interval_seconds':
                            mesos_handler.get_heartbeat_interval(),
                            'master_info':
                            master_info
                        }
                    })
                    if request.is_json:
                        subscribed = subscribed_json
                        logger.debug("json subscribed response")
                    else:
                        logger.debug("protobuf subscribed response")
                        subscribed_msg = json_format.Parse(
                            subscribed_json,
                            scheduler_pb2.Event(),
                            ignore_unknown_fields=False)
                        subscribed = subscribed_msg.SerializeToString()

                    logger.trace("subscribed=%s, yield it as recordio" %
                                 subscribed)
                    yield _recordio(subscribed)

                    logger.debug("subscribed before generate offer loop")
                    for offers in mesos_handler.http_generate_offers(
                            {'value': framework_name}):
                        if not offers:
                            logger.debug("in offer loop: skip empty offers")
                            continue
                        if request.is_json:
                            resp_event = json.dumps(offers)
                            logger.debug("json offer response")
                        else:
                            offer_msg = json_format.Parse(
                                json.dumps(offers),
                                scheduler_pb2.Event(),
                                ignore_unknown_fields=False)
                            resp_event = offer_msg.SerializeToString()
                            logger.debug("protobuf offer response")
                        logger.trace("offer: %s, yield it as recordio" %
                                     resp_event)
                        yield _recordio(resp_event)
                except Exception as ge:
                    # The 200 status is already committed once streaming has
                    # started, so errors can only be logged here.
                    logger.error("Exception in generator: %s" % ge)

            g = generate()
            framework_name = next(g)
            logger.debug("generated framework_name=%s" % framework_name)
            mimetype = "application/json" if request.is_json else "application/x-protobuf"
            resp = Response(stream_with_context(g),
                            status=200,
                            mimetype=mimetype)
            resp.headers['Mesos-Stream-Id'] = framework_name
            logger.debug("resp.headers=%s" % resp.headers)
            return resp

        elif call_type == 'TEARDOWN':
            logger.debug("TEARDOWN")
            framework_id = content['framework_id']
            mesos_handler.http_teardown(framework_id)
            return Response(status=202)

        elif call_type == 'ACCEPT':
            logger.debug("ACCEPT")
            framework_id = content['framework_id']
            accept = content['accept']
            offer_ids = accept['offer_ids']
            filters = accept.get('filters')
            operations = accept.get('operations')
            if operations:
                for operation in operations:
                    if operation['type'] != 'LAUNCH':
                        # Report the offending operation's type. The
                        # operations value is a list, so it cannot be indexed
                        # by 'type' itself.
                        msg = "ACCEPT: operation type %s is not supported" % operation[
                            'type']
                        return Response(msg, status=400)
                    # convert ranges to integers if necessary
                    for task in operation['launch'].get('task_infos', []):
                        for resource in task.get('resources', []):
                            for port_range in resource.get(
                                    'ranges', {}).get('range', []):
                                b = port_range['begin']
                                if isinstance(b, six.string_types):
                                    port_range['begin'] = int(b)
                                e = port_range['end']
                                if isinstance(e, six.string_types):
                                    port_range['end'] = int(e)
                    mesos_handler.http_accept_launch(
                        framework_id, offer_ids, filters,
                        operation['launch'])
            else:
                # same as decline
                mesos_handler.http_decline(framework_id, offer_ids, filters)
            return Response(status=202)

        elif call_type == 'DECLINE':
            logger.debug("DECLINE")
            framework_id = content['framework_id']
            offer_ids = content['decline']['offer_ids']
            filters = content['decline'].get('filters')
            mesos_handler.http_decline(framework_id, offer_ids, filters)
            return Response(status=202)

        elif call_type == 'REVIVE':
            logger.debug("REVIVE")
            framework_id = content['framework_id']
            revive = content['revive']
            mesos_handler.http_revive(framework_id, revive)
            return Response(status=202)

        elif call_type == 'KILL':
            logger.debug("KILL")
            framework_id = content['framework_id']
            task_id = content['kill']['task_id']
            agent_id = content['kill'].get('agent_id')
            mesos_handler.http_kill(framework_id, task_id, agent_id)
            return Response(status=202)

        elif call_type == 'SHUTDOWN':
            logger.debug("SHUTDOWN")
            framework_id = content['framework_id']
            executor_id = content['shutdown']['executor_id']
            agent_id = content['shutdown']['agent_id']
            mesos_handler.http_shutdown(framework_id, executor_id, agent_id)
            return Response(status=202)

        elif call_type == 'ACKNOWLEDGE':
            logger.debug("ACKNOWLEDGE")
            framework_id = content['framework_id']
            agent_id = content['acknowledge']['agent_id']
            task_id = content['acknowledge']['task_id']
            # Named ack_uuid so it does not shadow the uuid module.
            ack_uuid = content['acknowledge'].get('uuid')
            mesos_handler.http_acknowledge(framework_id, agent_id, task_id,
                                           ack_uuid)
            return Response(status=202)

        elif call_type == 'RECONCILE':
            logger.debug("RECONCILE")
            framework_id = content['framework_id']
            tasks = content.get('reconcile', {}).get('tasks', [])
            mesos_handler.http_reconcile(framework_id, tasks)
            return Response(status=202)

        elif call_type == 'MESSAGE':
            logger.debug("MESSAGE")
            framework_id = content['framework_id']
            agent_id = content['message']['agent_id']
            executor_id = content['message']['executor_id']
            data = content['message']['data']
            mesos_handler.http_message(framework_id, agent_id, executor_id,
                                       data)
            return Response(status=202)

        elif call_type == 'REQUEST':
            logger.debug("REQUEST")
            framework_id = content['framework_id']
            agent_id = content['requests']['agent_id']
            resources = content['requests']['resources']
            mesos_handler.http_request(framework_id, agent_id, resources)
            return Response(status=202)

        else:
            logger.error("Unkown content type: %s" % call_type)
            # Returning None here would be an invalid view response.
            return Response("Unknown call type: %s" % call_type, status=400)
    except Exception as e:
        msg = "Exception in scheduler endpoint: %s" % e
        logger.error(msg)
        return Response(msg, status=500)
예제 #10
0
    def update_from_vulnerability(self, vulnerability):
        """Set fields from vulnerability. Does not set the ID.

        Copies the following from the vulnerability proto:
        - summary and details;
        - reference URL types;
        - the optional modified/published/withdrawn timestamps;
        - aliases and related IDs.

        A vulnerability without ``affected`` entries is handed to
        ``_update_from_pre_0_8``. Otherwise ``affected_packages`` is rebuilt
        from those entries.
        """
        self.summary = vulnerability.summary
        self.details = vulnerability.details

        url_types = {}
        for reference in vulnerability.references:
            url_types[reference.url] = vulnerability_pb2.Reference.Type.Name(
                reference.type)
        self.reference_url_types = url_types

        # Optional timestamp field on the proto -> attribute it populates.
        timestamp_fields = (
            ('modified', 'last_modified'),
            ('published', 'timestamp'),
            ('withdrawn', 'withdrawn'),
        )
        for proto_field, attr_name in timestamp_fields:
            if vulnerability.HasField(proto_field):
                setattr(self, attr_name,
                        getattr(vulnerability, proto_field).ToDatetime())

        self.aliases = list(vulnerability.aliases)
        self.related = list(vulnerability.related)

        if not vulnerability.affected:
            self._update_from_pre_0_8(vulnerability)
            return

        self.affected_packages = []
        for proto_package in vulnerability.affected:
            package_entry = AffectedPackage()
            package_entry.package = Package(
                name=proto_package.package.name,
                ecosystem=proto_package.package.ecosystem,
                purl=proto_package.package.purl)
            package_entry.ranges = []

            for proto_range in proto_package.ranges:
                range_entry = AffectedRange2(
                    type=vulnerability_pb2.Range.Type.Name(proto_range.type),
                    repo_url=proto_range.repo,
                    events=[])
                for proto_event in proto_range.events:
                    # Record only the first populated kind, checked in order.
                    for event_type in ('introduced', 'fixed', 'limit'):
                        event_value = getattr(proto_event, event_type)
                        if event_value:
                            range_entry.events.append(
                                AffectedEvent(type=event_type,
                                              value=event_value))
                            break
                package_entry.ranges.append(range_entry)

            package_entry.versions = list(proto_package.versions)
            for struct_field in ('database_specific', 'ecosystem_specific'):
                struct_value = getattr(proto_package, struct_field)
                if struct_value:
                    setattr(package_entry, struct_field,
                            json_format.MessageToDict(
                                struct_value,
                                preserving_proto_field_name=True))

            self.affected_packages.append(package_entry)
예제 #11
0
async def parse_async_stream_request(request_iterator):
    """Adapt an async stream of protobuf requests into ``StreamMessage`` objects.

    Every incoming message is converted to a dict that keeps the field names
    exactly as declared in the ``.proto`` file, and that dict is expanded into
    ``StreamMessage`` keyword arguments. The original, unconverted message is
    attached to the wrapper as ``raw_data``.

    Args:
        request_iterator: Async iterable of protobuf request messages.

    Yields:
        StreamMessage: One wrapper per request, in arrival order.
    """
    async for request in request_iterator:
        fields = json_format.MessageToDict(request,
                                           preserving_proto_field_name=True)
        wrapped = StreamMessage(**fields)
        wrapped.raw_data = request
        yield wrapped
예제 #12
0
def _ToDict(rdfval: rdf_structs.RDFProtoStruct) -> JsonDict:
    """Convert an RDF proto struct into a JSON-compatible dict.

    The value is first lowered to its primitive protobuf message, which is then
    converted with float fields limited to 8 digits of precision.
    """
    primitive_proto = rdfval.AsPrimitiveProto()
    return json_format.MessageToDict(primitive_proto, float_precision=8)
예제 #13
0
 def data_type(self, value):
     """Store ``value`` as the API resource for the ``data_type`` property.

     A truthy ``value`` is serialized to a dict from its underlying ``_pb``
     message (presumably a proto-plus wrapper — confirm against callers). Any
     falsy value clears the property by storing ``None``.
     """
     resource = json_format.MessageToDict(value._pb) if value else None
     api_field = self._PROPERTY_TO_API_FIELD["data_type"]
     self._properties[api_field] = resource
예제 #14
0
import json
import os
import sys

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

# Convert a binary protobuf AddressBook file to JSON and write it through
# Spark to an S3-compatible store.
#
# Usage: <script> ADDRESS_BOOK_FILE
#
# The S3A endpoint and credentials may be overridden with the S3A_ENDPOINT,
# S3A_ACCESS_KEY and S3A_SECRET_KEY environment variables; the defaults are
# the values this script previously hard-coded.

# Validate arguments before paying for SparkSession start-up.
if len(sys.argv) != 2:
    print("Usage:", sys.argv[0], "ADDRESS_BOOK_FILE")
    sys.exit(-1)

confz = (
    SparkConf()
    .set("spark.hadoop.fs.s3a.endpoint",
         os.environ.get("S3A_ENDPOINT", "http://127.0.0.1:9000"))
    .set("spark.hadoop.fs.s3a.access.key",
         os.environ.get("S3A_ACCESS_KEY", "minio"))
    .set("spark.hadoop.fs.s3a.secret.key",
         os.environ.get("S3A_SECRET_KEY", "minio123"))
    .set("spark.hadoop.fs.s3a.path.style.access", "true")
    .set("spark.hadoop.fs.s3a.impl",
         "org.apache.hadoop.fs.s3a.S3AFileSystem")
)

spark = SparkSession.builder.master("local[3]").appName("Test4").config(
    conf=confz).getOrCreate()

address_book = Addressbook_pb2.AddressBook()
with open(sys.argv[1], "rb") as f:
    address_book.ParseFromString(f.read())

message = address_book
# MessageToDict returns a dict. Parallelizing a dict directly distributes only
# its *keys* (e.g. "people"), which spark.read.json cannot parse. Serialize
# the whole message to a single compact JSON document instead, so Spark
# receives exactly one well-formed record.
json_string = json.dumps(json_format.MessageToDict(message))
rddData = spark.sparkContext.parallelize([json_string])
jsonframe = spark.read.json(rddData)
jsonframe.write.mode("overwrite").format("json").save(
    "s3a://spark-test/jsonconvert")
예제 #15
0
def convert_metrics_proto_to_dict(
    metrics_for_slice: metrics_for_slice_pb2.MetricsForSlice,
    model_name: Optional[Text] = None
) -> Tuple[slicer.SliceKeyType, Optional[view_types.MetricsByOutputName]]:
    """Converts a MetricsForSlice proto into a (slice key, metrics dict) pair.

    Metrics are first grouped as
    ``{model_name: {output_name: {sub_key_id: {metric_name: value}}}}``.
    Entries of the ``metrics`` map are filed under model, output and sub-key
    ``''``. Each ``metric_keys_and_values`` entry is filed under its own key,
    and diff metrics get a ``_diff`` suffix on their name.

    The model to return is ``model_name`` if given. Otherwise it is the model
    that produced diff metrics, provided exactly one model did. When the chosen
    model name is non-empty, the metrics stored under model ``''`` are merged
    into it. If no model could be chosen and exactly one model is present,
    that model's metrics are returned.

    Args:
      metrics_for_slice: The proto to convert.
      model_name: Optional name of the model whose metrics should be returned.

    Returns:
      A tuple of the deserialized slice key and the selected
      ``{output_name: {sub_key_id: {metric_name: value}}}`` map. The map is
      ``None`` when the proto holds no metrics at all.

    Raises:
      ValueError: If metrics exist but no single model's metrics can be
        selected.
    """
    by_model = {}
    if metrics_for_slice.metrics:
        by_model[''] = {
            '': {
                '': _convert_proto_map_to_dict(metrics_for_slice.metrics)
            }
        }

    diff_model_name = None
    for key_and_value in metrics_for_slice.metric_keys_and_values:
        metric_key = key_and_value.key
        this_model = metric_key.model_name
        if metric_key.HasField('sub_key'):
            sub_key_id = str(metric_types.SubKey.from_proto(metric_key.sub_key))
        else:
            sub_key_id = ''
        sub_key_bucket = (by_model.setdefault(this_model, {})
                          .setdefault(metric_key.output_name, {})
                          .setdefault(sub_key_id, {}))

        if not metric_key.is_diff:
            metric_name = metric_key.name
        else:
            if diff_model_name is None:
                diff_model_name = this_model
            elif diff_model_name != this_model:
                # Diffs from several models: poison the default with '' so the
                # lookup below reports that no match was found.
                diff_model_name = ''
            metric_name = '{}_diff'.format(metric_key.name)

        sub_key_bucket[metric_name] = json_format.MessageToDict(
            key_and_value.value)

    selected = None
    model_names = list(by_model)
    target_model = model_name or diff_model_name
    if target_model in by_model:
        # Use the provided (or inferred) model name when it matches.
        selected = by_model[target_model]
        # Copy model-independent metrics (e.g. example_count) into the model.
        if target_model and '' in by_model:
            for output_name, per_sub_key in by_model[''].items():
                for sub_key_id, named_values in per_sub_key.items():
                    if not named_values:
                        continue
                    merged = selected.setdefault(output_name, {}).setdefault(
                        sub_key_id, {})
                    merged.update(named_values)
    elif not target_model and len(model_names) == 1:
        # No model requested or inferred: fall back to the only model present.
        selected = by_model[model_names[0]]
    elif model_names:
        raise ValueError('Fail to find metrics for model name: %s . '
                         'Available model names are [%s]' %
                         (model_name, ', '.join(model_names)))

    slice_key = slicer.deserialize_slice_key(metrics_for_slice.slice_key)
    return (slice_key, selected)
예제 #16
0
def pb_to_dict_converter(msg, primary_key=None):
    """Convert a protobuf message to a dict, optionally re-keying one field.

    Field names are kept exactly as declared in the ``.proto`` file
    (``preserving_proto_field_name=True``).

    Args:
      msg: The protobuf message to convert.
      primary_key: Optional field name whose value is moved under the
        ``'_id'`` key. The original key is removed. Presumably this produces a
        document-store id; confirm against callers.

    Returns:
      The converted dict.

    Raises:
      KeyError: If ``primary_key`` is given but absent from the converted dict.
        ``MessageToDict`` omits fields that hold their default value by
        default, so a primary key equal to ``0`` or ``""`` also counts as
        absent.
    """
    d = json_format.MessageToDict(msg, preserving_proto_field_name=True)
    if primary_key:
        # pop() removes the source key before '_id' is written. With
        # primary_key == '_id' the value is therefore kept, instead of being
        # assigned and then immediately deleted.
        d['_id'] = d.pop(primary_key)
    return d
예제 #17
0
 def parse_proto_message(self, message):
     """Convert ``message`` to a dict using this instance's decode options.

     The keyword arguments forwarded to ``json_format.MessageToDict`` come from
     the ``self.__decode_options__`` mapping.
     """
     decode_options = self.__decode_options__
     return json_format.MessageToDict(message, **decode_options)
예제 #18
0
    def register(self):
        """
        Registers a user by creating a subscription.

        Registration is pretty straightforward for now, since it does not require payments.
        The amount of slots and expiry of the subscription cannot be requested by the user yet either. This is linked to
        the previous point.
        Users register by sending a public key to the proper endpoint. This is exploitable atm, but will be solved when
        payments are introduced.

        Returns:
            :obj:`tuple`: A tuple containing the response (:obj:`str`) and response code (:obj:`int`). For accepted
            requests, the ``rcode`` is ``HTTP_OK`` and the response is a json with the fields of the backend's register
            reply plus the ``public_key``. For rejected requests, the ``rcode`` is ``HTTP_BAD_REQUEST`` and the value
            contains an application error code and an error message. Error codes can be found at ``common.errors``:

            - ``INVALID_REQUEST_FORMAT`` if the request body cannot be read as json.
            - ``REGISTRATION_MISSING_FIELD`` if no ``public_key`` is provided.
            - ``REGISTRATION_WRONG_FIELD_FORMAT`` if the backend rejects the provided ``public_key``.
        """

        remote_addr = get_remote_addr()
        self.logger.info("Received register request",
                         from_addr="{}".format(remote_addr))

        # Check that data type and content are correct. Abort otherwise.
        try:
            request_data = get_request_data_json(request)

        except InvalidParameter as e:
            self.logger.info("Received invalid register request",
                             from_addr="{}".format(remote_addr))
            return jsonify({
                "error": str(e),
                "error_code": errors.INVALID_REQUEST_FORMAT
            }), HTTP_BAD_REQUEST

        user_id = request_data.get("public_key")

        if user_id:
            try:
                r = self.stub.register(RegisterRequest(user_id=user_id))

                rcode = HTTP_OK
                response = json_format.MessageToDict(
                    r,
                    including_default_value_fields=True,
                    preserving_proto_field_name=True)
                response["public_key"] = user_id

            except grpc.RpcError as e:
                # The key was provided but the backend refused it (presumably
                # a malformed public key), so this is a format error.
                rcode = HTTP_BAD_REQUEST
                response = {
                    "error": e.details(),
                    "error_code": errors.REGISTRATION_WRONG_FIELD_FORMAT
                }

        else:
            # The request carried no public_key at all: a missing field.
            rcode = HTTP_BAD_REQUEST
            response = {
                "error": "public_key not found in register message",
                "error_code": errors.REGISTRATION_MISSING_FIELD,
            }

        self.logger.info("Sending response and disconnecting",
                         from_addr="{}".format(remote_addr),
                         response=response)

        return jsonify(response), rcode
예제 #19
0
  def _group_to_dag_spec(
      self,
      group: dsl.OpsGroup,
      inputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
      outputs: Dict[str, List[Tuple[dsl.PipelineParam, str]]],
      dependencies: Dict[str, List[_GroupOrOp]],
      pipeline_spec: pipeline_spec_pb2.PipelineSpec,
      deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
      rootgroup_name: str,
      op_to_parent_groups: Dict[str, List[str]],
  ) -> None:
    """Generate IR spec given an OpsGroup.

    For every direct child (sub-group or op) of ``group``, this method:

    - adds a task spec to the group's DAG;
    - registers the child's component spec in ``pipeline_spec.components``
      if it is not already present;
    - wires the child's inputs, ParallelFor iterators, condition trigger
      policies and task dependencies;
    - records importer, custom-job or container executors in
      ``deployment_config``.

    Afterwards ``deployment_config`` is merged into
    ``pipeline_spec.deployment_spec``, and metrics outputs are surfaced to the
    top-level DAG.

    Args:
      group: The OpsGroup to generate spec for.
      inputs: The inputs dictionary. The keys are group/op names and values are
        lists of tuples (param, producing_op_name).
      outputs: The outputs dictionary. The keys are group/op names and values
        are lists of tuples (param, producing_op_name). Not currently read by
        this method.
      dependencies: The group dependencies dictionary. The keys are group/op
        names, and the values are lists of dependent groups/ops.
      pipeline_spec: The pipeline_spec to update in-place.
      deployment_config: The deployment_config to hold all executors.
      rootgroup_name: The name of the group root. Used to determine whether the
        component spec for the current group should be the root dag.
      op_to_parent_groups: The dict of op name to parent groups. Key is the op's
        name. Value is a list of ancestor groups including the op itself. The
        list of a given op is sorted in a way that the farthest group is the
        first and the op itself is the last.

    Raises:
      NotImplementedError: If a child is a ``dsl.graph_component`` or a
        ``dsl.ExitHandler`` group, neither of which the v2 compiler supports.
    """
    group_component_name = dsl_utils.sanitize_component_name(group.name)

    if group.name == rootgroup_name:
      group_component_spec = pipeline_spec.root
    else:
      group_component_spec = pipeline_spec.components[group_component_name]

    # Generate task specs and component specs for the dag.
    subgroups = group.groups + group.ops

    # Names of all tasks living directly in this DAG. The value depends only
    # on the fixed list of children, so compute it once, not once per child.
    tasks_in_current_dag = [
        dsl_utils.sanitize_task_name(child.name) for child in subgroups
    ]

    for subgroup in subgroups:
      subgroup_task_spec = getattr(subgroup, 'task_spec',
                                   pipeline_spec_pb2.PipelineTaskSpec())
      subgroup_component_spec = getattr(subgroup, 'component_spec',
                                        pipeline_spec_pb2.ComponentSpec())

      is_recursive_subgroup = (
          isinstance(subgroup, dsl.OpsGroup) and subgroup.recursive_ref)

      # Special handling for recursive subgroup: use the existing opsgroup name
      if is_recursive_subgroup:
        subgroup_key = subgroup.recursive_ref.name
      else:
        subgroup_key = subgroup.name

      subgroup_task_spec.task_info.name = (
          subgroup_task_spec.task_info.name or
          dsl_utils.sanitize_task_name(subgroup_key))
      # human_name exists for ops only, and is used to de-dupe component spec.
      subgroup_component_name = (
          subgroup_task_spec.component_ref.name or
          dsl_utils.sanitize_component_name(
              getattr(subgroup, 'human_name', subgroup_key)))
      subgroup_task_spec.component_ref.name = subgroup_component_name

      if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'graph':
        raise NotImplementedError(
            'dsl.graph_component is not yet supported in KFP v2 compiler.')

      if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'exit_handler':
        raise NotImplementedError(
            'dsl.ExitHandler is not yet supported in KFP v2 compiler.')

      if isinstance(subgroup, dsl.ContainerOp):
        if hasattr(subgroup, 'importer_spec'):
          # Importer ops carry ready-made task, component and importer specs;
          # copy them into the DAG, the component map and the executors.
          importer_task_name = subgroup.task_spec.task_info.name
          importer_comp_name = subgroup.task_spec.component_ref.name
          importer_exec_label = subgroup.component_spec.executor_label
          group_component_spec.dag.tasks[importer_task_name].CopyFrom(
              subgroup.task_spec)
          pipeline_spec.components[importer_comp_name].CopyFrom(
              subgroup.component_spec)
          deployment_config.executors[importer_exec_label].importer.CopyFrom(
              subgroup.importer_spec)

      subgroup_inputs = inputs.get(subgroup.name, [])
      subgroup_params = [param for param, _ in subgroup_inputs]

      input_parameters_in_current_dag = [
          input_name
          for input_name in group_component_spec.input_definitions.parameters
      ]
      input_artifacts_in_current_dag = [
          input_name
          for input_name in group_component_spec.input_definitions.artifacts
      ]

      # NOTE(review): `==` on protobuf messages compares contents, not
      # identity. `group.name == rootgroup_name` may be the intended test;
      # confirm before changing.
      is_parent_component_root = group_component_spec == pipeline_spec.root

      if isinstance(subgroup, dsl.ContainerOp):
        dsl_component_spec.update_task_inputs_spec(
            subgroup_task_spec,
            group_component_spec.input_definitions,
            subgroup_params,
            tasks_in_current_dag,
            input_parameters_in_current_dag,
            input_artifacts_in_current_dag,
        )

      if isinstance(subgroup, dsl.ParallelFor):
        if subgroup.parallelism is not None:
          warnings.warn(
              'Setting parallelism in ParallelFor is not supported yet. '
              'The setting is ignored.')

        # "Punch the hole", adding additional inputs (other than loop arguments
        # which will be handled separately) needed by its subgroup or tasks.
        loop_subgroup_params = []
        for param in subgroup_params:
          if isinstance(
              param, (_for_loop.LoopArguments, _for_loop.LoopArgumentVariable)):
            continue
          loop_subgroup_params.append(param)

        if subgroup.items_is_pipeline_param:
          # This loop_args is a 'withParam' rather than a 'withItems'.
          # i.e., rather than a static list, it is either the output of
          # another task or an input as global pipeline parameters.
          loop_subgroup_params.append(
              subgroup.loop_args.items_or_pipeline_param)

        dsl_component_spec.build_component_inputs_spec(
            component_spec=subgroup_component_spec,
            pipeline_params=loop_subgroup_params,
            is_root_component=False,
        )
        dsl_component_spec.build_task_inputs_spec(
            subgroup_task_spec,
            loop_subgroup_params,
            tasks_in_current_dag,
            is_parent_component_root,
        )

        if subgroup.items_is_pipeline_param:
          input_parameter_name = (
              dsl_component_spec.additional_input_name_for_pipelineparam(
                  subgroup.loop_args.items_or_pipeline_param))
          loop_arguments_item = '{}-{}'.format(
              input_parameter_name, _for_loop.LoopArguments.LOOP_ITEM_NAME_BASE)

          subgroup_component_spec.input_definitions.parameters[
              loop_arguments_item].type = pipeline_spec_pb2.PrimitiveType.STRING
          subgroup_task_spec.parameter_iterator.items.input_parameter = (
              input_parameter_name)
          subgroup_task_spec.parameter_iterator.item_input = (
              loop_arguments_item)

          # If the loop arguments itself is a loop arguments variable, handle
          # the subvar name.
          loop_args_name, subvar_name = (
              dsl_component_spec._exclude_loop_arguments_variables(
                  subgroup.loop_args.items_or_pipeline_param))
          if subvar_name:
            subgroup_task_spec.inputs.parameters[
                input_parameter_name].parameter_expression_selector = (
                    'parseJson(string_value)["{}"]'.format(subvar_name))
            subgroup_task_spec.inputs.parameters[
                input_parameter_name].component_input_parameter = (
                    dsl_component_spec.additional_input_name_for_pipelineparam(
                        loop_args_name))

        else:
          # Static 'withItems' loop: embed the item list as raw JSON.
          input_parameter_name = (
              dsl_component_spec.additional_input_name_for_pipelineparam(
                  subgroup.loop_args.full_name))
          raw_values = subgroup.loop_args.to_list_for_task_yaml()

          subgroup_component_spec.input_definitions.parameters[
              input_parameter_name].type = pipeline_spec_pb2.PrimitiveType.STRING
          subgroup_task_spec.parameter_iterator.items.raw = json.dumps(
              raw_values, sort_keys=True)
          subgroup_task_spec.parameter_iterator.item_input = (
              input_parameter_name)

      if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'condition':

        # "punch the hole", adding inputs needed by its subgroup or tasks.
        dsl_component_spec.build_component_inputs_spec(
            component_spec=subgroup_component_spec,
            pipeline_params=subgroup_params,
            is_root_component=False,
        )
        dsl_component_spec.build_task_inputs_spec(
            subgroup_task_spec,
            subgroup_params,
            tasks_in_current_dag,
            is_parent_component_root,
        )

        condition = subgroup.condition

        operand1_value, operand2_value = self._resolve_condition_operands(
            condition.operand1, condition.operand2)

        condition_string = '{} {} {}'.format(operand1_value, condition.operator,
                                             operand2_value)

        subgroup_task_spec.trigger_policy.CopyFrom(
            pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(
                condition=condition_string))

      # Generate dependencies section for this task.
      if dependencies.get(subgroup.name, None):
        group_dependencies = list(dependencies[subgroup.name])
        group_dependencies.sort()
        subgroup_task_spec.dependent_tasks.extend(
            [dsl_utils.sanitize_task_name(dep) for dep in group_dependencies])

      # Add component spec if not exists
      if subgroup_component_name not in pipeline_spec.components:
        pipeline_spec.components[subgroup_component_name].CopyFrom(
            subgroup_component_spec)

      # Add task spec
      group_component_spec.dag.tasks[
          subgroup_task_spec.task_info.name].CopyFrom(subgroup_task_spec)

      # Add AIPlatformCustomJobSpec, if applicable.
      custom_job_spec = getattr(subgroup, 'custom_job_spec', None)
      if custom_job_spec:
        executor_label = subgroup_component_spec.executor_label
        if executor_label not in deployment_config.executors:
          deployment_config.executors[
              executor_label].custom_job.custom_job.update(custom_job_spec)

      # Add executor spec, if applicable.
      container_spec = getattr(subgroup, 'container_spec', None)
      # Ignore container_spec if custom_job_spec exists.
      if container_spec and not custom_job_spec:
        if compiler_utils.is_v2_component(subgroup):
          compiler_utils.refactor_v2_container_spec(container_spec)
        executor_label = subgroup_component_spec.executor_label

        if executor_label not in deployment_config.executors:
          deployment_config.executors[executor_label].container.CopyFrom(
              container_spec)

    pipeline_spec.deployment_spec.update(
        json_format.MessageToDict(deployment_config))

    # Surface metrics outputs to the top.
    self._populate_metrics_in_dag_outputs(
        group.ops,
        op_to_parent_groups,
        pipeline_spec,
    )
예제 #20
0
    def add_appointment(self):
        """
        Main endpoint of the Watchtower.

        Clients send appointments to this endpoint to request a job from the Watchtower. The request body must be json
        and carry an ``appointment`` and a ``signature`` field.

        Returns:
            :obj:`tuple`: A tuple containing the response (:obj:`str`) and response code (:obj:`int`):

            - Accepted appointment: ``HTTP_OK`` with the backend's reply (json).
            - Unreadable request body: ``HTTP_BAD_REQUEST`` with ``INVALID_REQUEST_FORMAT``.
            - Failed inspection: ``HTTP_BAD_REQUEST`` with the inspector's error number.
            - Backend reports ``UNAUTHENTICATED`` or ``ALREADY_EXISTS``: ``HTTP_BAD_REQUEST`` with the matching error
              code.
            - Any other backend error: ``HTTP_SERVICE_UNAVAILABLE`` with a bare ``error`` message.

            Error codes can be found at ``common.errors``.
        """

        # Getting the real IP if the server is behind a reverse proxy
        remote_addr = get_remote_addr()
        self.logger.info("Received add_appointment request",
                         from_addr="{}".format(remote_addr))

        # Bail out early when the body is not a well-formed json request.
        try:
            request_data = get_request_data_json(request)
        except InvalidParameter as e:
            return jsonify({
                "error": str(e),
                "error_code": errors.INVALID_REQUEST_FORMAT
            }), HTTP_BAD_REQUEST

        try:
            inspected = self.inspector.inspect(request_data.get("appointment"))
            wire_appointment = Appointment(
                locator=inspected.locator,
                encrypted_blob=inspected.encrypted_blob,
                to_self_delay=inspected.to_self_delay,
            )
            outgoing = AddAppointmentRequest(
                appointment=wire_appointment,
                signature=request_data.get("signature"),
            )
            reply = self.stub.add_appointment(outgoing)

            rcode = HTTP_OK
            response = json_format.MessageToDict(
                reply,
                including_default_value_fields=True,
                preserving_proto_field_name=True)

        except InspectionFailed as e:
            rcode = HTTP_BAD_REQUEST
            response = {
                "error": "appointment rejected. {}".format(e.reason),
                "error_code": e.erno
            }

        except grpc.RpcError as e:
            # Backend statuses that map to a client-side rejection.
            client_rejections = {
                grpc.StatusCode.UNAUTHENTICATED:
                errors.APPOINTMENT_INVALID_SIGNATURE_OR_SUBSCRIPTION_ERROR,
                grpc.StatusCode.ALREADY_EXISTS:
                errors.APPOINTMENT_ALREADY_TRIGGERED,
            }
            status = e.code()
            if status in client_rejections:
                rcode = HTTP_BAD_REQUEST
                response = {
                    "error": f"appointment rejected. {e.details()}",
                    "error_code": client_rejections[status],
                }
            else:
                # RESOURCE_EXHAUSTED, and any other status, lands here.
                rcode = HTTP_SERVICE_UNAVAILABLE
                response = {"error": "appointment rejected"}

        self.logger.info("Sending response and disconnecting",
                         from_addr="{}".format(remote_addr),
                         response=response)
        return jsonify(response), rcode
예제 #21
0
def parse_and_upload(
    fs,
    dst_path_rt,
    tmp_dir,
    hour: RTHourlyAggregation,
    verbose=False,
    pbar=None,
) -> List[RTFileProcessingOutcome]:
    """Parse one hour of GTFS-RT protobuf files into gzipped JSONL and upload it.

    Each file in ``hour.source_files`` is read from ``dst_path_rt`` and decoded
    as a ``FeedMessage``. Every feed entity becomes one JSON line holding the
    feed ``header``, the file's ``metadata`` and the entity's own fields. All
    lines for the hour are written to one local gzip file under ``tmp_dir``.
    That file is uploaded to ``hour.data_hive_path`` only if at least one line
    was written.

    Args:
        fs: Filesystem object passed through to ``put_with_retry``.
        dst_path_rt: Local directory containing the raw RT files.
        tmp_dir: Local scratch directory (``str`` or path-like) for the gzip.
        hour: The hourly aggregation describing inputs and the output path.
        verbose: If True, log a warning for each file that fails to parse or
            has no entities.
        pbar: Optional progress bar forwarded to ``log``.

    Returns:
        One ``RTFileProcessingOutcome`` per source file. Successful files carry
        the number of output records. Failed files carry the ``DecodeError``,
        or a ``ValueError`` when the file contains no entities.
    """
    written = 0
    outcomes = []
    gzip_fname = os.path.join(tmp_dir,
                              f"data_{hour.suffix}{JSONL_GZIP_EXTENSION}")

    # ParseFromString() seems to not release memory well, so manually handle
    # writing to the gzip and cleaning up after ourselves

    with gzip.open(gzip_fname, "w") as gzipfile:
        for rt_file in hour.source_files:
            feed = gtfs_realtime_pb2.FeedMessage()

            try:
                with open(
                        os.path.join(dst_path_rt,
                                     rt_file.timestamped_filename), "rb") as f:
                    feed.ParseFromString(f.read())
                parsed = json_format.MessageToDict(feed)
            except DecodeError as e:
                if verbose:
                    log(
                        f"WARNING: DecodeError for {str(rt_file.path)}",
                        fg=typer.colors.YELLOW,
                        pbar=pbar,
                    )
                outcomes.append(
                    RTFileProcessingOutcome(
                        step="parse",
                        success=False,
                        exception=e,
                        file=rt_file,
                    ))
                continue

            if not parsed or "entity" not in parsed:
                msg = f"WARNING: no records found in {str(rt_file.path)}"
                if verbose:
                    log(
                        msg,
                        fg=typer.colors.YELLOW,
                        pbar=pbar,
                    )
                outcomes.append(
                    RTFileProcessingOutcome(
                        step="parse",
                        success=False,
                        exception=ValueError(msg),
                        file=rt_file,
                    ))
                continue

            # Back and forth through JSON so we use pydantic serialization.
            # The result depends only on the file, so compute it once here
            # rather than once per entity.
            metadata = json.loads(rt_file.json())
            for record in parsed["entity"]:
                # Each line is serialized immediately and `record` is never
                # mutated, so a shallow unpack suffices; deepcopy is not needed.
                line = json.dumps({
                    "header": parsed["header"],
                    "metadata": metadata,
                    **record,
                })
                gzipfile.write((line + "\n").encode("utf-8"))
                written += 1
            outcomes.append(
                RTFileProcessingOutcome(
                    step="parse",
                    success=True,
                    file=rt_file,
                    n_output_records=len(parsed["entity"]),
                    hive_path=hour.data_hive_path,
                ))
            del parsed

    if written:
        log(
            f"writing {written} lines to {hour.data_hive_path}",
            pbar=pbar,
        )
        put_with_retry(fs, gzip_fname, f"{hour.data_hive_path}")
    else:
        log(
            f"WARNING: no records at all for {hour.data_hive_path}",
            fg=typer.colors.YELLOW,
            pbar=pbar,
        )

    return outcomes
예제 #22
0
    def get_appointment(self):
        """
        Reports the state of one appointment held by the Watchtower, looked up by ``locator``.

        The request body must be json carrying a ``locator`` and a ``signature``.

        Returns:
            :obj:`tuple`: A tuple containing the response (:obj:`str`) and response code (:obj:`int`).

            - Found: ``HTTP_OK`` with the ``locator``, the backend-reported ``status`` and the appointment (or tracker)
              data as a dictionary.
            - Unreadable request body: ``HTTP_BAD_REQUEST`` with ``INVALID_REQUEST_FORMAT``.
            - Backend reports ``UNAUTHENTICATED``: ``HTTP_BAD_REQUEST`` with
              ``APPOINTMENT_INVALID_SIGNATURE_OR_SUBSCRIPTION_ERROR``.
            - Invalid locator or any other backend error: ``HTTP_NOT_FOUND`` with the ``locator`` and
              ``AppointmentStatus.NOT_FOUND``.
        """

        # Getting the real IP if the server is behind a reverse proxy
        remote_addr = get_remote_addr()
        from_addr = "{}".format(remote_addr)

        # Bail out early when the body is not a well-formed json request.
        try:
            request_data = get_request_data_json(request)
        except InvalidParameter as e:
            self.logger.info("Received invalid get_appointment request",
                             from_addr=from_addr)
            return jsonify({
                "error": str(e),
                "error_code": errors.INVALID_REQUEST_FORMAT
            }), HTTP_BAD_REQUEST

        locator = request_data.get("locator")

        try:
            self.inspector.check_locator(locator)
            self.logger.info("Received get_appointment request",
                             from_addr=from_addr,
                             locator=locator)

            reply = self.stub.get_appointment(
                GetAppointmentRequest(locator=locator,
                                      signature=request_data.get("signature")))

            # The payload is a oneof: either the watched appointment or the
            # responder's tracker.
            payload = reply.appointment_data
            if payload.WhichOneof("appointment_data") == "appointment":
                data = payload.appointment
            else:
                data = payload.tracker

            rcode = HTTP_OK
            response = {
                "locator": locator,
                "status": reply.status,
                "appointment": json_format.MessageToDict(
                    data,
                    including_default_value_fields=True,
                    preserving_proto_field_name=True),
            }

        except (InspectionFailed, grpc.RpcError) as e:
            unauthenticated = (isinstance(e, grpc.RpcError) and
                               e.code() == grpc.StatusCode.UNAUTHENTICATED)
            if not unauthenticated:
                rcode = HTTP_NOT_FOUND
                response = {
                    "locator": locator,
                    "status": AppointmentStatus.NOT_FOUND
                }
            else:
                rcode = HTTP_BAD_REQUEST
                response = {
                    "error": e.details(),
                    "error_code":
                    errors.APPOINTMENT_INVALID_SIGNATURE_OR_SUBSCRIPTION_ERROR,
                }

        return jsonify(response), rcode
예제 #23
0
def struct2dict(struct: StructProto) -> Dict:
    """Convert a ``google.protobuf.Struct`` message into a plain Python dict.

    Args:
        struct: The Struct message to unpack.

    Returns:
        The dict produced by ``json_format.MessageToDict`` for ``struct``.
    """
    unpacked = json_format.MessageToDict(struct)
    return unpacked
예제 #24
0
    def _GetV2AnalysisResultFromV1(self, request):
        """Constructs v2 analysis results based on v1 analysis.

    This is a temporary work around to make sure Findit's analysis results for
    chromium build failures are still available on SoM during v1 to v2
    migration.

    Args:
      request (findit_result.BuildFailureAnalysisRequest)

    Returns:
      [findit_result.BuildFailureAnalysisResponse] for results of a v1 analysis,
      otherwise return None.
    """
        if (request.build_alternative_id
                and request.build_alternative_id.project != 'chromium'):
            return None

        build = None
        if request.build_id:
            build = buildbucket_client.GetV2Build(
                request.build_id,
                fields=FieldMask(
                    paths=['id', 'number', 'builder', 'output.properties']))
        elif request.build_alternative_id:
            build = buildbucket_client.GetV2BuildByBuilderAndBuildNumber(
                request.build_alternative_id.project,
                request.build_alternative_id.bucket,
                request.build_alternative_id.builder,
                request.build_alternative_id.number,
                fields=FieldMask(
                    paths=['id', 'number', 'builder', 'output.properties']))

        if not build:
            logging.error('Failed to download build when requesting for %s',
                          request)
            return None

        if build.builder.project != 'chromium':
            return None

        properties = json_format.MessageToDict(build.output.properties)
        build_number = build.number
        master_name = properties.get('target_mastername',
                                     properties.get('mastername'))
        if not build_number or not master_name:
            logging.error('Missing master_name or build_number for build %d',
                          build.id)
            return None

        heuristic_analysis = WfAnalysis.Get(master_name, build.builder.builder,
                                            build_number)
        if not heuristic_analysis:
            return None

        results = []
        v1_build_request = _BuildFailure(builder_name=build.builder.builder,
                                         build_number=build_number)
        self._GenerateResultsForBuild(v1_build_request, heuristic_analysis,
                                      results, None)
        return self._GetV2ResultFromV1(request, results)
예제 #25
0
 def _dump():
     """Write ``ret`` as JSON to ``args.json`` when that option is set.

     Proto field names are kept as declared (``preserving_proto_field_name``).
     NOTE(review): ``args``/``ret`` come from the enclosing scope, and
     ``args.json`` is presumably a writable file object - confirm with caller.
     """
     if not args.json:
         return
     payload = json_format.MessageToDict(ret, preserving_proto_field_name=True)
     json.dump(payload, args.json)
예제 #26
0
 def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
     """Log an outgoing RPC with its request payload and call kwargs.

     Args:
       rpc: Name of the RPC method being invoked.
       req: The request message; logged as its class name plus the
         ``json_format.MessageToDict`` form of it.
       call_kwargs: Extra keyword arguments of the call, rendered as
         comma-separated ``k=v`` pairs.
       log_level: Level to log at; ``None`` falls back to ``logging.DEBUG``.
     """
     level = logging.DEBUG if log_level is None else log_level
     # Join a generator (not a set) so the pairs keep call_kwargs' insertion
     # order; a set's iteration order depends on string hash randomization,
     # which made this log line differ between runs.
     kwargs_str = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
     logger.log(level, 'RPC %s.%s(request=%s(%r), %s)', self.log_service_name,
                rpc, req.__class__.__name__, json_format.MessageToDict(req),
                kwargs_str)
예제 #27
0
    def test_affinity(self) -> None:  # pylint: disable=too-many-statements
        """Checks header-based session affinity end to end.

        Creates the Traffic Director resources with an affinity header, starts
        the test servers and a client that sends that header, and verifies the
        client's CSDS config (endpoint count, hash-policy header, RING_HASH).
        Then checks that all RPCs land on a single backend, marks that backend
        NOT_SERVING, waits for the endpoint removal to propagate, and checks
        that the RPCs move together to a different single backend.
        """

        with self.subTest('00_create_health_check'):
            self.td.create_health_check()

        with self.subTest('01_create_backend_services'):
            self.td.create_backend_service(
                affinity_header=_TEST_AFFINITY_METADATA_KEY)

        with self.subTest('02_create_url_map'):
            self.td.create_url_map(self.server_xds_host, self.server_xds_port)

        with self.subTest('03_create_target_proxy'):
            self.td.create_target_proxy()

        with self.subTest('04_create_forwarding_rule'):
            self.td.create_forwarding_rule(self.server_xds_port)

        test_servers: List[_XdsTestServer]
        with self.subTest('05_start_test_servers'):
            test_servers = self.startTestServers(replica_count=_REPLICA_COUNT)

        with self.subTest('06_add_server_backends_to_backend_services'):
            self.setupServerBackends()

        test_client: _XdsTestClient
        with self.subTest('07_start_test_client'):
            test_client = self.startTestClient(test_servers[0],
                                               rpc='EmptyCall',
                                               metadata='EmptyCall:%s:123' %
                                               _TEST_AFFINITY_METADATA_KEY)
            # Validate the number of received endpoints and affinity configs.
            config = test_client.csds.fetch_client_status(
                log_level=logging.INFO)
            self.assertIsNotNone(config)
            json_config = json_format.MessageToDict(config)
            parsed = xds_url_map_testcase.DumpedXdsConfig(json_config)
            logging.info('Client received CSDS response: %s', parsed)
            self.assertLen(parsed.endpoints, _REPLICA_COUNT)
            self.assertEqual(
                parsed.rds['virtualHosts'][0]['routes'][0]['route']
                ['hashPolicy'][0]['header']['headerName'],
                _TEST_AFFINITY_METADATA_KEY)
            self.assertEqual(parsed.cds[0]['lbPolicy'], 'RING_HASH')

        with self.subTest('08_test_client_xds_config_exists'):
            self.assertXdsConfigExists(test_client)

        with self.subTest('09_test_server_received_rpcs_from_test_client'):
            self.assertSuccessfulRpcs(test_client)

        with self.subTest('10_first_100_affinity_rpcs_pick_same_backend'):
            rpc_stats = self.getClientRpcStats(test_client, _RPC_COUNT)
            json_lb_stats = json_format.MessageToDict(rpc_stats)
            rpc_distribution = xds_url_map_testcase.RpcDistributionStats(
                json_lb_stats)
            self.assertEqual(1, rpc_distribution.num_peers)
            self.assertLen(
                test_client.find_subchannels_with_state(
                    _ChannelzChannelState.READY),
                1,
            )
            self.assertLen(
                test_client.find_subchannels_with_state(
                    _ChannelzChannelState.IDLE),
                2,
            )
            # Remember the backend inuse, and turn it down later.
            first_backend_inuse = list(
                rpc_distribution.raw['rpcsByPeer'].keys())[0]

        with self.subTest('11_turn_down_server_in_use'):
            for s in test_servers:
                if s.pod_name == first_backend_inuse:
                    logging.info('setting backend %s to NOT_SERVING',
                                 s.pod_name)
                    s.set_not_serving()

        with self.subTest('12_wait_for_unhealth_status_propagation'):
            deadline = time.time() + _TD_PROPAGATE_TIMEOUT
            parsed = None
            try:
                while time.time() < deadline:
                    config = test_client.csds.fetch_client_status(
                        log_level=logging.INFO)
                    self.assertIsNotNone(config)
                    json_config = json_format.MessageToDict(config)
                    parsed = xds_url_map_testcase.DumpedXdsConfig(json_config)
                    if len(parsed.endpoints) == _REPLICA_COUNT - 1:
                        break
                    logging.info(
                        'CSDS got unexpected endpoints, will retry after %d seconds',
                        _TD_PROPAGATE_CHECK_INTERVAL_SEC)
                    time.sleep(_TD_PROPAGATE_CHECK_INTERVAL_SEC)
                else:
                    # while/else: only reached when the deadline expired
                    # without a `break`. Report the real timeout instead of a
                    # hard-coded number that can drift from the constant.
                    self.fail('unhealthy status did not propagate after '
                              f'{_TD_PROPAGATE_TIMEOUT} seconds')
            finally:
                logging.info('Client received CSDS response: %s', parsed)

        # Numbered 13 so it no longer shares the "12_" prefix with the
        # propagation-wait subtest above.
        with self.subTest('13_next_100_affinity_rpcs_pick_different_backend'):
            rpc_stats = self.getClientRpcStats(test_client, _RPC_COUNT)
            json_lb_stats = json_format.MessageToDict(rpc_stats)
            rpc_distribution = xds_url_map_testcase.RpcDistributionStats(
                json_lb_stats)
            self.assertEqual(1, rpc_distribution.num_peers)
            new_backend_inuse = list(
                rpc_distribution.raw['rpcsByPeer'].keys())[0]
            self.assertNotEqual(new_backend_inuse, first_backend_inuse)
예제 #28
0
def _convert_proto_map_to_dict(proto_map: Any) -> Dict[Text, Dict[Text, Any]]:
    """Turns a metrics map (``metrics`` of a MetricsForSlice proto) into a dict.

    A protobuf MessageMap behaves like a dict with string keys and protobuf
    message values, but it is not itself a message. So it cannot be handed to
    ``json_format.MessageToDict`` as a whole; each value is converted on its
    own instead.

    Args:
      proto_map: The protobuf MessageMap to convert.

    Returns:
      A dict that maps every key of ``proto_map`` to the
      ``json_format.MessageToDict`` rendering of its value. For example, if
      myProto contains

      {
        metrics: {
          key: 'double'
          value: {
            double_value: {
              value: 1.0
            }
          }
        }
        metrics: {
          key: 'bounded'
          value: {
            bounded_value: {
              lower_bound: {
                double_value: {
                  value: 0.8
                }
              }
              upper_bound: {
                double_value: {
                  value: 0.9
                }
              }
              value: {
                double_value: {
                  value: 0.86
                }
              }
            }
          }
        }
      }

      then _convert_proto_map_to_dict(myProto.metrics) yields

      {
        'double': {
          'doubleValue': 1.0,
        },
        'bounded': {
          'boundedValue': {
            'lowerBound': 0.8,
            'upperBound': 0.9,
            'value': 0.86,
          },
        },
      }

      Field names come out in lowerCamelCase, and google.protobuf.DoubleValue
      wrappers collapse to their plain value.
    """
    return {
        name: json_format.MessageToDict(message)
        for name, message in proto_map.items()
    }
    def _cancel_operation(
            self,
            request: operations_pb2.CancelOperationRequest,
            *,
            retry: OptionalRetry = gapic_v1.method.DEFAULT,
            timeout: Optional[float] = None,
            metadata: Sequence[Tuple[str, str]] = (),
    ) -> empty_pb2.Empty:
        r"""Call the cancel operation method over HTTP.

        Args:
            request (~.operations_pb2.CancelOperationRequest):
                The request object. The request message for
                [Operations.CancelOperation][google.api_core.operations_v1.Operations.CancelOperation].

            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. NOTE(review): this method body does not
                consult it.
            timeout (float): The timeout for this request.
            metadata (Sequence[Tuple[str, str]]): Strings which should be
                sent along with the request as metadata.

        Returns:
            ~.empty_pb2.Empty: An empty message, returned when the HTTP status
            code is below 400.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: The exception built
                by ``core_exceptions.from_http_response`` when the HTTP status
                code is 400 or higher.
        """

        http_options = [
            {
                "method": "post",
                "uri": "/v1/{name=operations/**}:cancel",
                "body": "*"
            },
        ]
        # Allow per-transport overrides of the default HTTP rule.
        if "google.longrunning.Operations.CancelOperation" in self._http_options:
            http_options = self._http_options[
                "google.longrunning.Operations.CancelOperation"]

        request_kwargs = json_format.MessageToDict(
            request,
            preserving_proto_field_name=True,
            including_default_value_fields=True,
        )
        transcoded_request = path_template.transcode(http_options,
                                                     **request_kwargs)

        # Jsonify the request body. It is serialized to a JSON *string*: the
        # request below declares Content-Type: application/json, and a dict
        # passed as `data=` to a requests-style session would be form-encoded
        # instead.
        body_request = operations_pb2.CancelOperationRequest()
        json_format.ParseDict(transcoded_request["body"], body_request)
        body = json_format.MessageToJson(
            body_request,
            including_default_value_fields=False,
            preserving_proto_field_name=False,
            use_integers_for_enums=False,
        )
        uri = transcoded_request["uri"]
        method = transcoded_request["method"]

        # Jsonify the query params (kept as a dict; flattened below).
        query_params_request = operations_pb2.CancelOperationRequest()
        json_format.ParseDict(transcoded_request["query_params"],
                              query_params_request)
        query_params = json_format.MessageToDict(
            query_params_request,
            including_default_value_fields=False,
            preserving_proto_field_name=False,
            use_integers_for_enums=False,
        )

        # Send the request
        headers = dict(metadata)
        headers["Content-Type"] = "application/json"
        response = getattr(self._session, method)(
            "https://{host}{uri}".format(host=self._host, uri=uri),
            timeout=timeout,
            headers=headers,
            params=rest_helpers.flatten_query_params(query_params),
            data=body,
        )

        # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
        # subclass.
        if response.status_code >= 400:
            raise core_exceptions.from_http_response(response)

        return empty_pb2.Empty()
예제 #30
0
def bucket_get_iam_policy(bucket_name):
    """Return the IAM policy of ``bucket_name`` as a dict tagged with its kind.

    NOTE(review): ``db.insert_test_bucket(None)`` is called first; presumably
    it makes sure a test bucket exists - confirm against ``db``.

    Returns:
        ``json_format.MessageToDict`` of the bucket's ``iam_policy``, with
        ``"kind"`` set to ``"storage#policy"``.
    """
    db.insert_test_bucket(None)
    policy_proto = db.get_bucket(flask.request, bucket_name, None).iam_policy
    policy = json_format.MessageToDict(policy_proto)
    policy["kind"] = "storage#policy"
    return policy