def read_tfrecords(tfrecords):
    """Parse an in-memory TFRecord byte string into a list of feature dicts.

    Each record is framed as: an 8-byte little-endian length, a 4-byte masked
    CRC32C of those 8 bytes, ``length`` payload bytes, then a 4-byte masked
    CRC32C of the payload. Every payload is parsed as a ``tf.train.Example``
    and flattened to ``{feature_name: first_value}``.

    :param tfrecords: raw TFRecord contents.
    :type tfrecords: bytes
    :returns: one dict per record, in stream order.
    :rtype: list
    :raises ValueError: on a truncated header/payload or a CRC mismatch.
    """
    header_size = 12
    stream = io.BytesIO(tfrecords)
    parsed = []
    while True:
        header = stream.read(header_size)
        if not header:
            # Clean end of the buffer: every record has been consumed.
            return parsed
        if len(header) != header_size:
            raise ValueError('TFrecord is fewer than %d bytes' % header_size)
        payload_len, header_crc = struct.unpack('<QI', header)
        if _masked_crc32c(header[:8]) != header_crc:
            raise ValueError('TFRecord does not contain a valid length mask')

        # The payload is immediately followed by its own 4-byte CRC.
        body_size = payload_len + 4
        body = stream.read(body_size)
        if len(body) != body_size:
            raise ValueError('TFRecord data payload has fewer bytes than specified in header')
        payload, payload_crc = struct.unpack('<%dsI' % payload_len, body)
        if _masked_crc32c(payload) != payload_crc:
            raise ValueError('TFRecord has an invalid data crc32c')

        example = tf.train.Example()
        example.ParseFromString(payload)
        features = MessageToDict(example.features)['feature']

        # Each feature dict is expected to hold a single typed list (the
        # Feature oneof); keep only the first element of that list.
        row = {}
        for name, typed_list in features.items():
            first_kind = list(typed_list.keys())[0]
            row[name] = typed_list[first_kind]['value'][0]
        parsed.append(row)
def metric_update(self, project, metric_name, filter_, description):
    """API call: replace a log-based metric definition.

    :type project: str
    :param project: ID of the project that owns the metric.

    :type metric_name: str
    :param metric_name: the metric's name.

    :type filter_: str
    :param filter_: advanced logs filter expression selecting the entries
                    the metric counts.

    :type description: str
    :param description: human-readable description of the metric.

    :rtype: dict
    :returns: the metric returned by the API, converted from its protobuf
              form into a dictionary.
    :raises NotFound: if the API reports the metric does not exist.
    """
    metric_path = 'projects/%s/metrics/%s' % (project, metric_name)
    requested_metric = LogMetric(name=metric_path, filter=filter_, description=description)
    try:
        updated_metric = self._gax_api.update_log_metric(
            metric_path, requested_metric, options=None)
    except GaxError as exc:
        if exc_to_code(exc.cause) != StatusCode.NOT_FOUND:
            raise
        raise NotFound(metric_path)
    # NOTE: LogMetric message type does not have an ``Any`` field
    # so `MessageToDict`` can safely be used.
    return MessageToDict(updated_metric)
def transcribe_gcs(audio_file_path):
    """Asynchronously transcribe (Czech, cs-CZ) an audio file stored in GCS.

    The audio is read from ``gs://<bucket>/<basename(audio_file_path)>``; the
    upload step is currently disabled (see TODO), so the object must already
    exist in the bucket. Two files are written next to ``audio_file_path``:
    ``<audio_file_path>.json`` with the full response converted via
    ``MessageToDict``, and ``<audio_file_path>.txt`` with every transcript
    alternative on its own line.

    :param audio_file_path: local path of the audio file; its basename is used
        as the GCS object name.
    """
    bucket_name = 'european-germany-bucket'  # Your gcloud bucket name
    print(audio_file_path)
    audio_file_name = osp.basename(audio_file_path)
    print(audio_file_name)
    # TODO: check whether the file is already uploaded and upload it only if it's missing
    # upload_to_gcloud(bucket_name, source_file_name=audio_file_path, destination_blob_name=audio_file_name)

    client = speech.SpeechClient()
    audio = types.RecognitionAudio(uri="gs://" + bucket_name + "/" + audio_file_name)
    config = types.RecognitionConfig(
        encoding=enums.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
        language_code='cs-CZ',
        # sample_rate_hertz=16000,
        enable_word_time_offsets=True)

    operation = client.long_running_recognize(config, audio)
    while not operation.done():
        print('Waiting for results...')
        time.sleep(30)  # 30 seconds
    response = operation.result()

    with open(audio_file_path + '.json', 'w', encoding='utf-8') as json_file:
        # indent=1 is what the previous indent=True (a bool) amounted to.
        json.dump(MessageToDict(response), json_file, indent=1)

    with open(audio_file_path + '.txt', 'w', encoding='utf-8') as raw_text_file:
        # Distinct loop name so the operation's response is not shadowed.
        for segment in response.results:
            for alternative in segment.alternatives:
                raw_text_file.write(alternative.transcript + '\n')
def get_product(
    self,
    location: str,
    product_id: str,
    project_id: str = PROVIDE_PROJECT_ID,
    retry: Union[Retry, _MethodDefault] = DEFAULT,
    timeout: Optional[float] = None,
    metadata: Sequence[Tuple[str, str]] = (),
):
    """
    Fetch a single Product from Vision Product Search and return it as a dict.

    For the documentation see:
    :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`
    """
    vision_client = self.get_conn()
    product_name = ProductSearchClient.product_path(project_id, location, product_id)
    self.log.info('Retrieving Product: %s', product_name)
    call_options = {'retry': retry, 'timeout': timeout, 'metadata': metadata}
    product = vision_client.get_product(name=product_name, **call_options)
    self.log.info('Product retrieved.')
    self.log.debug('Product retrieved:\n%s', product)
    return MessageToDict(product)
def test_update_storage(self, mock_parse_request, *args):
    """Update a storage via the servicer and check the returned StorageInfo."""
    params = {
        'storage_id': utils.generate_id('storage'),
        'name': 'update-storage-name',
        'tags': {
            'update_key': 'update_value'
        },
        'domain_id': utils.generate_id('domain')
    }
    mock_parse_request.return_value = (params, {})

    storage_servicer = Storage()
    storage_info = storage_servicer.update(params, {})
    # Label previously read 'test_update_schedule' (copy-paste leftover).
    print_message(storage_info, 'test_update_storage')

    # Check the proto type first so a wrong type fails with a clear message.
    self.assertIsInstance(storage_info, storage_pb2.StorageInfo)
    storage_data = MessageToDict(storage_info, preserving_proto_field_name=True)
    self.assertEqual(storage_data['name'], params['name'])
    self.assertEqual(storage_data['storage_id'], params['storage_id'])
    self.assertDictEqual(storage_data['tags'], params['tags'])

    print(f'[TEST UPDATE STORAGE] {storage_data}')
def get_account_summary_and_this_device(params: KeeperParams):
    """Fetch the account summary and find the entry describing this device.

    :param params: session parameters. ``params.config['device_token']`` is
        used when present; otherwise the token comes from
        ``rest_api.get_device_token(params.rest_context)``.
    :returns: ``(acct_summary_dict, this_device)``. ``acct_summary_dict`` is the
        account summary converted with ``MessageToDict``. ``this_device`` is the
        first entry of its ``devices`` list whose ``encryptedDeviceToken``
        matches the current token (ignoring non-alphanumerics), or ``None``.
    """
    def to_alphanumerics(text):
        # remove ALL non-alphanumerics ([\W_] also strips '_')
        return re.sub(r'[\W_]+', '', text)

    def compare_device_tokens(t1: str, t2: str):
        # Punctuation is ignored, presumably because the two sides may use
        # different base64 alphabets/padding — TODO confirm.
        return to_alphanumerics(t1) == to_alphanumerics(t2)

    acct_summary = loginv3.LoginV3API.accountSummary(params)
    acct_summary_dict = MessageToDict(acct_summary)
    # MessageToDict omits empty repeated fields and empty strings/bytes by
    # default, so 'devices' (or a device's token) may be missing entirely.
    devices = acct_summary_dict.get('devices', [])

    if 'device_token' not in params.config:
        current_device_token = rest_api.get_device_token(params.rest_context)
    else:
        current_device_token = params.config['device_token']

    this_device = next(
        (item for item in devices
         if 'encryptedDeviceToken' in item
         and compare_device_tokens(item['encryptedDeviceToken'], current_device_token)),
        None)
    return acct_summary_dict, this_device
def predict():
    """Run remote inference on a JSON-posted image and return JSON results.

    Response body: ``{"message": str, "results": list}``. ``message`` is empty
    on success. On failure it holds a preprocessing error, an
    "inference failed: ..." note, or a request-format hint.
    """
    res = {"message": "", "results": []}
    if not request.json:
        res['message'] = 'please post JSON format data'
        return jsonify(res)

    req_dict = request.get_json()
    try:
        # convert image to protobuf
        image = image_preprocess(req_dict)
    except Exception as e:
        current_app.logger.error(f'pre handler image error: {str(e)}')
        res['message'] = str(e)
        return jsonify(res)

    # put to predict
    try:
        remote_results = app.config['predict'].Predict(image)
        res['results'] = MessageToDict(remote_results)['results']
    except Exception as e:
        # grpc.RpcError exposes details()/code() as *methods*; the old code
        # logged the uncalled method and crashed on any other exception type
        # (e.g. KeyError when the response carries no 'results').
        details_fn = getattr(e, 'details', None)
        code_fn = getattr(e, 'code', None)
        details = details_fn() if callable(details_fn) else str(e)
        code = code_fn() if callable(code_fn) else type(e).__name__
        current_app.logger.error(details)
        res['message'] = f"inference failed: {code}"
    return jsonify(res)
def DescribeDeployment(self, request, context=None):
    """Describe a deployment and persist its latest state.

    The namespace defaults to ``self.default_namespace``. If the deployment
    exists, its operator describes it. When that call reports OK, the
    returned state is saved through ``deployment_store.update_deployment``.

    Returns one of:
      * the operator's response;
      * a NOT_FOUND response when the deployment does not exist;
      * an INTERNAL response when a BentoMLException is raised.
    """
    try:
        request.namespace = request.namespace or self.default_namespace
        deployment_pb = self.deployment_store.get(request.deployment_name, request.namespace)
        if not deployment_pb:
            return DescribeDeploymentResponse(status=Status.NOT_FOUND(
                'Deployment "{}" in namespace "{}" not found'.format(
                    request.deployment_name, request.namespace)))

        operator = get_deployment_operator(deployment_pb)
        response = operator.describe(deployment_pb, self.repo)
        if response.status.status_code == status_pb2.Status.OK:
            with self.deployment_store.update_deployment(
                    request.deployment_name, request.namespace) as deployment:
                deployment.state = MessageToDict(response.state)
        return response
    except BentoMLException as e:
        logger.error("INTERNAL ERROR: %s", e)
        # Protobuf constructors accept keyword arguments only (positional ->
        # TypeError). Status.INTERNAL presumably fills a string error-message
        # field, so pass the text rather than the exception object.
        return DescribeDeploymentResponse(status=Status.INTERNAL(str(e)))
def update_product(
    self,
    product: Union[dict, Product],
    project_id: str,
    location: Optional[str] = None,
    product_id: Optional[str] = None,
    update_mask: Optional[Dict[str, FieldMask]] = None,
    retry: Optional[Retry] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
    """
    For the documentation see:
    :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`

    ``product_name_determiner.get_entity_with_name`` presumably fills in the
    product's resource name from ``product_id``/``location``/``project_id``.
    The updated Product returned by the API is converted to a dict.
    """
    client = self.get_conn()
    product = self.product_name_determiner.get_entity_with_name(product, product_id, location, project_id)
    # This call updates a Product; the message previously said "ProductSet".
    self.log.info('Updating Product: %s', product.name)
    response = client.update_product(
        product=product, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
    )
    self.log.info('Product updated: %s', response.name if response else '')
    self.log.debug('Product updated:\n%s', response)
    return MessageToDict(response)
def send_request(server_url, model_name, signature_name, input_name, image_paths):
    """Send a batch of images to a TensorFlow Serving model over gRPC.

    :param server_url: ``host:port`` of the serving endpoint (insecure channel).
    :param model_name: served model name.
    :param signature_name: signature to invoke.
    :param input_name: key of the input tensor in the request.
    :param image_paths: paths passed to ``load_imgs``.
    :returns: the ``outputs`` entry of the PredictResponse, converted with
        ``MessageToDict`` (proto field names preserved).
    """
    imgs = load_imgs(image_paths)

    # Allow 100 MiB messages in both directions. The previous key,
    # 'grpc.max_message_length', is not one of gRPC core's current channel
    # arguments; the explicit send/receive keys are.
    max_message_length = 100 * 1024 * 1024
    options = [
        ('grpc.max_send_message_length', max_message_length),
        ('grpc.max_receive_message_length', max_message_length),
    ]
    # The channel is closed on exit (it was previously never closed).
    with grpc.insecure_channel(server_url, options=options) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        # create the request object and set the name and signature_name params
        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_name
        request.model_spec.signature_name = signature_name
        # fill in the request object with the necessary data
        request.inputs[input_name].CopyFrom(
            tf.contrib.util.make_tensor_proto(imgs, shape=imgs.shape))

        # Second positional argument is the RPC timeout in seconds.
        result = stub.Predict(request, 200000.)

    output_dict = MessageToDict(result, preserving_proto_field_name=True)
    return output_dict['outputs']
def analyze_labels_file(self, file):
    """Run label detection on a video read from a file-like object.

    Prints every segment-level label with its categories, segment time ranges
    and confidences, then returns the full annotation result as a dict with
    proto field names preserved. Waits up to 90 seconds for the operation.
    """
    def to_seconds(offset):
        return offset.seconds + offset.nanos / 1e9

    client = videointelligence.VideoIntelligenceServiceClient()
    requested_features = [videointelligence.enums.Feature.LABEL_DETECTION]
    video_bytes = file.read()
    operation = client.annotate_video(features=requested_features, input_content=video_bytes)
    print("\nProcessing video for label annotations:")

    annotation = operation.result(timeout=90)
    print("\nFinished processing.")

    # Video/segment-level label annotations of the first (only) input.
    for label in annotation.annotation_results[0].segment_label_annotations:
        print("Video label description: {}".format(label.entity.description))
        for category in label.category_entities:
            print("\tLabel category description: {}".format(category.description))
        for index, labeled_segment in enumerate(label.segments):
            span = labeled_segment.segment
            positions = "{}s to {}s".format(to_seconds(span.start_time_offset),
                                            to_seconds(span.end_time_offset))
            print("\tSegment {}: {}".format(index, positions))
            print("\tConfidence: {}".format(labeled_segment.confidence))
        print("\n")

    return MessageToDict(annotation, preserving_proto_field_name=True)
def simulate_eapol_flow_install(self, ldev_id, olt_id, onu_ids):
    """Install one EAPOL-to-controller flow per ONU on a logical device.

    Emulates the flow-mod requests the SDN controller would send. For each
    ONU, a flow matching its logical port, VLAN (OFPVID_PRESENT | 0) and
    eth_type 0x888e with output to the controller is posted at priority 1000.
    Afterwards the device's flow table must hold at least 4 flows.
    """
    ports_url = '/api/v1/logical_devices/{}/ports'.format(ldev_id)
    flows_url = '/api/v1/logical_devices/{}/flows'.format(ldev_id)

    # device_id -> logical port, used to pick each ONU's in_port
    lport_by_device = {lp['device_id']: lp for lp in self.get(ports_url)['items']}

    for onu_id in onu_ids:
        onu_port_no = lport_by_device[onu_id]['ofp_port']['port_no']
        flow_update = ofp.FlowTableUpdate(
            id='ponsim1',
            flow_mod=mk_simple_flow_mod(
                match_fields=[
                    in_port(onu_port_no),
                    vlan_vid(ofp.OFPVID_PRESENT | 0),
                    eth_type(0x888e),
                ],
                actions=[output(ofp.OFPP_CONTROLLER)],
                priority=1000,
            ),
        )
        payload = MessageToDict(flow_update, preserving_proto_field_name=True)
        self.post(flows_url, payload, expected_code=200)

    # for sanity, verify that flows are in flow table of logical device
    installed = self.get(flows_url)['items']
    self.assertGreaterEqual(len(installed), 4)
def save(self, request, context):
    """Create a State under its Country from a gRPC request.

    Requires an auth token authorised for '03_state_save'. The request is
    converted to a dict; its 'country' entry selects the Country, and the
    remaining fields build the States document. The document is saved and
    pushed onto the country's ``states`` list. Returns a StateResponse with
    the saved state.
    """
    try:
        token = parser_context(context, 'auth_token')
        is_auth(token, '03_state_save')

        fields = MessageToDict(request)
        country_id = fields.pop('country')
        country = Countries.objects.get(id=country_id)

        new_state = States(**fields)
        new_state.country = country
        new_state.save()
        country.update(push__states=new_state)

        return state_pb2.StateResponse(state=parser_one_object(new_state))
    except NotUniqueError as e:
        # NOTE(review): nothing is returned on this path, so the handler yields
        # None; presumably exist_code sets the gRPC error on context — confirm.
        exist_code(context, e)
def json_snapshot_message(snapshot, user_id, color_image_path, depth_image_path):
    """Convert a protocol snapshot to JSON, replacing binary image data with paths.

    The snapshot is copied and the ``data`` field of ``color_image`` and
    ``depth_image`` is cleared on the copy, so the caller's message is
    untouched. The storage paths are written into the resulting dict in place
    of the cleared data.

    :param snapshot: snapshot message to serialize
    :type snapshot: Snapshot
    :param user_id: id of the user the snapshot belongs to
    :type user_id: int
    :param color_image_path: path where the color image data is stored
    :type color_image_path: str
    :param depth_image_path: path where the depth image data is stored
    :type depth_image_path: str
    :returns: snapshot information as a JSON document
    :rtype: str
    """
    snapshot_metadata = Snapshot()
    snapshot_metadata.CopyFrom(snapshot)
    snapshot_metadata.color_image.ClearField('data')
    snapshot_metadata.depth_image.ClearField('data')
    snapshot_dict = MessageToDict(snapshot_metadata,
                                  preserving_proto_field_name=True,
                                  including_default_value_fields=True)
    snapshot_dict['user_id'] = user_id
    # including_default_value_fields does not force singular message fields
    # into the output, so an unset image sub-message would be absent here.
    snapshot_dict.setdefault('color_image', {})['data'] = color_image_path
    snapshot_dict.setdefault('depth_image', {})['data'] = depth_image_path
    return json.dumps(snapshot_dict)
def test_update_role_policies(self):
    """Replace a role's policies and check they come back unchanged."""
    self.test_create_role()
    self._test_create_policy(['identity.*'])

    new_policies = [
        {'policy_id': policy.policy_id, 'policy_type': 'CUSTOM'}
        for policy in self.policies
    ]
    update_params = {
        'role_id': self.role.role_id,
        'policies': new_policies,
        'domain_id': self.domain.domain_id,
    }
    self.role = self.identity_v1.Role.update(
        update_params, metadata=(('token', self.owner_token),))
    # NOTE(review): prints self.policy rather than the updated self.role — confirm intent.
    self._print_data(self.policy, 'test_update_role_policies')

    role_info = MessageToDict(self.role, preserving_proto_field_name=True)
    self.assertEqual(role_info['policies'], new_policies)
async def get_online_features(request: Request):
    """Serve online features for a JSON-encoded GetOnlineFeaturesRequest.

    Features come from ``feature_service`` when set, otherwise from the
    explicit ``features`` list. Entity columns must all have the same length.
    Returns the response proto converted to a dict with proto field names.

    :raises HTTPException: status 500 for no entities, uneven entity columns,
        or any other failure (logged server-side).
    """
    try:
        # Validate and parse the request data into GetOnlineFeaturesRequest Protobuf object
        body = await request.body()
        request_proto = GetOnlineFeaturesRequest()
        Parse(body, request_proto)

        # Initialize parameters for FeatureStore.get_online_features(...) call
        if request_proto.HasField("feature_service"):
            features = store.get_feature_service(request_proto.feature_service)
        else:
            features = list(request_proto.features.val)

        full_feature_names = request_proto.full_feature_names

        batch_sizes = [len(v.val) for v in request_proto.entities.values()]
        if not batch_sizes:
            # Previously surfaced as an opaque "list index out of range".
            raise HTTPException(status_code=500, detail="No entities provided")
        num_entities = batch_sizes[0]
        if any(batch_size != num_entities for batch_size in batch_sizes):
            raise HTTPException(status_code=500, detail="Uneven number of columns")

        entity_rows = [
            {k: v.val[idx] for k, v in request_proto.entities.items()}
            for idx in range(num_entities)
        ]

        response_proto = store.get_online_features(
            features, entity_rows, full_feature_names=full_feature_names
        ).proto

        # Convert the Protobuf object to JSON and return it
        return MessageToDict(response_proto, preserving_proto_field_name=True)
    except HTTPException:
        # Already a client-facing error: re-raise unchanged instead of
        # re-wrapping it, which garbled the detail string.
        raise
    except Exception as e:
        # Print the original exception on the server side
        logger.exception(e)
        # Raise HTTPException to return the error message to the client
        raise HTTPException(status_code=500, detail=str(e)) from e
def update_product(
    self,
    product,
    location=None,
    product_id=None,
    update_mask=None,
    project_id=None,
    retry=None,
    timeout=None,
    metadata=None,
):
    """
    For the documentation see:
    :class:`~airflow.contrib.operators.gcp_vision_operator.CloudVisionProductUpdateOperator`

    ``product_name_determiner.get_entity_with_name`` presumably fills in the
    product's resource name from ``product_id``/``location``/``project_id``.
    The updated Product returned by the API is converted to a dict.
    """
    client = self.get_conn()
    product = self.product_name_determiner.get_entity_with_name(product, product_id, location, project_id)
    # This call updates a Product; the message previously said "ProductSet".
    self.log.info('Updating Product: %s', product.name)
    response = client.update_product(
        product=product, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
    )
    self.log.info('Product updated: %s', response.name if response else '')
    self.log.debug('Product updated:\n%s', response)
    return MessageToDict(response)
def SetPhase(self, request, context):
    """Update the job phase.

    The caller is authenticated with the build-job JWT in ``request.job_jwt``.
    On failure the error is reported as UNAUTHENTICATED and an empty
    SetPhaseResponse is returned. Otherwise the gRPC phase is mapped to a
    build phase and handed to the lifecycle manager together with any
    ``pull_metadata`` (as a dict keyed by proto field names). The response
    echoes the request's sequence number.
    """
    token, error_msg = self._decode_build_token(request.job_jwt, BUILD_JOB_TOKEN_TYPE)
    if not token:
        self._handle_error(context, grpc.StatusCode.UNAUTHENTICATED, error_msg)
        return buildman_pb2.SetPhaseResponse()

    job_id = token["job_id"]
    if request.HasField("pull_metadata"):
        metadata = dict(MessageToDict(request.pull_metadata, preserving_proto_field_name=True))
    else:
        metadata = {}

    build_phase = self.GRPC_PHASE_TO_BUILD_PHASE[request.phase]
    updated = self._lifecycle_manager.update_job_phase(job_id, build_phase, metadata)
    return buildman_pb2.SetPhaseResponse(success=updated, sequence_number=request.sequence_number)
def test_list_and_update_local_logical_device_flows(self):
    """Add ten random flows to 'simulated1' and check its flow table grew."""
    flows_url = '/api/v1/local/logical_devices/simulated1/flows'
    count_before = len(self.get(flows_url)['items'])

    started = time()
    for _ in xrange(10):
        # Random cookie and priority keep each added flow distinct.
        flow_update = ofp.FlowTableUpdate(
            id='simulated1',
            flow_mod=mk_simple_flow_mod(
                cookie=randint(1, 10000000000),
                priority=randint(1, 10000),
                match_fields=[in_port(129)],
                actions=[output(1)]))
        payload = MessageToDict(flow_update, preserving_proto_field_name=True)
        self.post(flows_url, payload, expected_code=200)
    print(time() - started)

    count_after = len(self.get(flows_url)['items'])
    self.assertGreater(count_after, count_before)
def tfdv_detect_drift(
    stats_older_path: str, stats_new_path: str
) -> NamedTuple('Outputs', [('drift', str)]):
    """Decide whether to retrain by checking the new statistics for drift.

    A schema is inferred from the older statistics. A Jensen-Shannon
    divergence threshold of 0.01 is set on the 'duration' feature, and the
    newer statistics are validated against that schema with the older ones
    as the baseline. Imports live inside the body, presumably because this
    runs as a self-contained pipeline component — TODO confirm.

    :param stats_older_path: path to the baseline statistics, or the literal
        string 'none' when no baseline exists.
    :param stats_new_path: path to the new statistics.
    :returns: ``('true',)`` to train (no baseline, or divergence >= threshold);
        ``('false',)`` when the measured divergence is below the threshold.
    """
    import logging
    # Unused `import time` removed.
    import tensorflow_data_validation as tfdv
    # Kept as in the original; possibly imported for its side effects — TODO confirm.
    import tensorflow_data_validation.statistics.stats_impl

    logging.getLogger().setLevel(logging.INFO)
    logging.info('stats_older_path: %s', stats_older_path)
    logging.info('stats_new_path: %s', stats_new_path)

    if stats_older_path == 'none':
        return ('true', )

    stats1 = tfdv.load_statistics(stats_older_path)
    stats2 = tfdv.load_statistics(stats_new_path)

    schema1 = tfdv.infer_schema(statistics=stats1)
    tfdv.get_feature(schema1, 'duration').drift_comparator.jensen_shannon_divergence.threshold = 0.01
    drift_anomalies = tfdv.validate_statistics(
        statistics=stats2, schema=schema1, previous_statistics=stats1)
    logging.info('drift analysis results: %s', drift_anomalies.drift_skew_info)

    from google.protobuf.json_format import MessageToDict
    d = MessageToDict(drift_anomalies)
    # NOTE(review): assumes drift_skew_info has at least one entry with a
    # measurement; otherwise this raises KeyError/IndexError.
    val = d['driftSkewInfo'][0]['driftMeasurements'][0]['value']
    thresh = d['driftSkewInfo'][0]['driftMeasurements'][0]['threshold']
    logging.info('value %s and threshold %s', val, thresh)
    res = 'true'
    if val < thresh:
        res = 'false'
    logging.info('train decision: %s', res)
    return (res, )
def gcloud_clusters_describe_command(client: ClusterManagerClient, project: str = "", cluster: str = "",
                                     zone: str = "") -> COMMAND_OUTPUT:
    """ Describe a single GKE cluster.
    https://cloud.google.com/sdk/gcloud/reference/container/clusters/describe

    Args:
        client: Google container client.
        project: GCP project from console.
        cluster: Cluster ID, e.g. "dmst-gcloud-cluster-1".
        zone: Project query zone, e.g. "europe-west2-a".

    Returns:
        str: Markdown table for the war room.
        dict: Entry context keyed by CLUSTER_CONTEXT.
        dict: The raw API response converted to a dict.
    """
    cluster_msg: Message = client.get_cluster(cluster_id=cluster, project_id=project, zone=zone,
                                              timeout=API_TIMEOUT)
    cluster_raw: dict = MessageToDict(cluster_msg)

    parsed_cluster = parse_cluster(cluster_raw)
    entry_context = {CLUSTER_CONTEXT: parsed_cluster}

    table_title = f'Clusters (Project={project}, Zone={zone}, Cluster={cluster})'
    human_readable: str = tableToMarkdown(t=parse_cluster_table(parsed_cluster), name=table_title)
    return human_readable, entry_context, cluster_raw
def _protobuf_to_adsmsg_citation_change(pure_protobuf):
    """
    Build an adsmsg.CitationChange from a plain citation_change protobuf so it
    can be sent via Celery/RabbitMQ.

    The enum ``content_type`` arrives from MessageToDict as its name and is
    mapped back to the numeric value (default 0, i.e. doi). The string form of
    ``timestamp`` is dropped and the original Timestamp message is copied over
    after construction.
    """
    fields = MessageToDict(pure_protobuf, preserving_proto_field_name=True)

    if 'content_type' not in fields:
        fields['content_type'] = 0  # default: adsmsg.CitationChangeContentType.doi
    else:
        fields['content_type'] = getattr(adsmsg.CitationChangeContentType, fields['content_type'])

    # String form looks like '2019-01-03T21:00:02.010610Z'; discard it.
    has_timestamp = 'timestamp' in fields
    fields.pop('timestamp', None)

    citation_change = adsmsg.CitationChange(**fields)
    if has_timestamp:
        # Restore as google.protobuf.timestamp_pb2.Timestamp (seconds/nanos).
        citation_change.timestamp = pure_protobuf.timestamp
    return citation_change
def gcloud_operations_list_command(client: ClusterManagerClient, project: str, zone: str) -> COMMAND_OUTPUT:
    """ List the container operations of a project/zone.
    https://cloud.google.com/sdk/gcloud/reference/container/operations/list

    Args:
        client: Google container client.
        project: GCP project from console.
        zone: Project query zone, e.g. "europe-west2-a".

    Returns:
        str: Markdown table of the parsed operations.
        dict: Entry context keyed by OPERATION_CONTEXT.
        dict: The raw API response converted to a dict.
    """
    operations_msg: Message = client.list_operations(project_id=project, zone=zone, timeout=API_TIMEOUT)
    operations_raw: dict = MessageToDict(operations_msg)

    # MessageToDict omits the key entirely when there are no operations.
    parsed_operations: List[dict] = list(map(parse_operation, operations_raw.get('operations', [])))
    entry_context = {OPERATION_CONTEXT: parsed_operations}

    human_readable: str = tableToMarkdown(
        t=parsed_operations,
        headers=OPERATION_TABLE,
        name=f'Project {project} - Zone {zone} - Operations',
    )
    return human_readable, entry_context, operations_raw
def to_dict(self: db.Model):
    """Convert a SQLAlchemy model instance into a plain dict.

    Every table column except those listed in ``ignores`` is copied, then each
    ``extras`` entry adds ``key: func(self)``. Values are then normalised:
    datetimes become integer epoch seconds (naive datetimes are interpreted in
    local time by ``datetime.timestamp``), protobuf Messages become dicts
    (proto field names, default values included), and Enums become their
    ``name``.
    """
    dic = {
        col.key: getattr(self, col.key)
        for col in self.__table__.columns
        if col.key not in ignores
    }

    for extra_key, func in extras.items():
        dic[extra_key] = func(self)

    # Only values are replaced, so mutating while iterating items() is safe.
    for key, value in dic.items():
        if isinstance(value, datetime):
            dic[key] = int(value.timestamp())
        elif isinstance(value, Message):
            dic[key] = MessageToDict(
                value,
                preserving_proto_field_name=True,
                including_default_value_fields=True)
        elif isinstance(value, Enum):
            dic[key] = value.name
    return dic
def call_google_ocr_api(base64_image_id):
    """Run Google Cloud Vision text detection on a base64-encoded image.

    Credentials are loaded from the service-account JSON file named by the
    GOOGLE_APPLICATION_CREDENTIALS_OCR environment variable.

    :param base64_image_id: base64-encoded image bytes.
    :returns: the text-detection response converted to a JSON string.
    :raises Exception: if GOOGLE_APPLICATION_CREDENTIALS_OCR is not set.
    """
    json_file_path = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS_OCR')
    if json_file_path is None:
        raise Exception(
            "please set GOOGLE_APPLICATION_CREDENTIALS_OCR variable to your google key path"
        )
    # Close the key file deterministically (it was previously left open).
    with open(json_file_path, encoding='utf-8') as key_file:
        service_account_info = json.load(key_file)
    credentials = service_account.Credentials.from_service_account_info(
        service_account_info)
    client = vision.ImageAnnotatorClient(credentials=credentials)

    # The API takes raw image bytes, so decode the base64 payload first.
    content = base64.b64decode(base64_image_id)
    image = types.Image(content=content)

    # Performs text detection on the image
    response = client.text_detection(image)
    return json.dumps(MessageToDict(response))
def metric_get(self, project, metric_name):
    """API call: fetch a single log-based metric.

    :type project: str
    :param project: ID of the project that owns the metric.

    :type metric_name: str
    :param metric_name: the metric's name.

    :rtype: dict
    :returns: the metric returned by the API, converted from its protobuf
              form into a dictionary.
    :raises NotFound: if the API reports the metric does not exist.
    """
    metric_path = 'projects/%s/metrics/%s' % (project, metric_name)
    try:
        fetched_metric = self._gax_api.get_log_metric(metric_path, options=None)
    except GaxError as exc:
        if exc_to_code(exc.cause) != StatusCode.NOT_FOUND:
            raise
        raise NotFound(metric_path)
    # NOTE: LogMetric message type does not have an ``Any`` field
    # so `MessageToDict`` can safely be used.
    return MessageToDict(fetched_metric)
def test_update_ip_data(self):
    """Updating an allocated IP's ``data`` stores the nested dict unchanged."""
    self._create_subnet()
    address = '172.16.1.10'
    self.test_allocate_ip(subnet_id=self.subnet.subnet_id, ip=address)

    expected_data = {
        'xxxxx': 'bbbb',
        'yyyyyy': 'zzzz',
        'aaaaa': {'bbbbb': 'cccccc'},
    }
    self.ip = self.inventory_v1.IPAddress.update(
        {
            'ip_address': address,
            'subnet_id': self.subnet.subnet_id,
            'domain_id': self.domain.domain_id,
            'data': expected_data,
        },
        metadata=(('token', self.token),),
    )
    self.assertEqual(MessageToDict(self.ip.data), expected_data)
def get_product(self, location: str, product_id: str, project_id: Optional[str] = None,
                retry: Optional[Retry] = None, timeout: Optional[float] = None,
                metadata: Optional[Sequence[Tuple[str, str]]] = None):
    """
    Fetch a single Product from Vision Product Search and return it as a dict.

    For the documentation see:
    :class:`~airflow.contrib.operators.gcp_vision_operator.CloudVisionProductGetOperator`

    :raises ValueError: if ``project_id`` is empty or None.
    """
    if not project_id:
        raise ValueError("Project ID should be set.")
    vision_client = self.get_conn()
    product_name = ProductSearchClient.product_path(project_id, location, product_id)
    self.log.info('Retrieving Product: %s', product_name)
    call_options = {'retry': retry, 'timeout': timeout, 'metadata': metadata}
    product = vision_client.get_product(name=product_name, **call_options)
    self.log.info('Product retrieved.')
    self.log.debug('Product retrieved:\n%s', product)
    return MessageToDict(product)
def sink_update(self, project, sink_name, filter_, destination,
                unique_writer_identity=False):
    """API call: replace a log sink definition.

    :type project: str
    :param project: ID of the project that owns the sink.

    :type sink_name: str
    :param sink_name: the sink's name.

    :type filter_: str
    :param filter_: advanced logs filter expression selecting the entries
                    the sink exports.

    :type destination: str
    :param destination: URI the exported entries are sent to.

    :type unique_writer_identity: bool
    :param unique_writer_identity: (Optional) controls which kind of IAM
                                   identity is returned as writer_identity.

    :rtype: dict
    :returns: the sink returned by the API, converted from its protobuf form
              into a dictionary.
    """
    sink_path = 'projects/%s/sinks/%s' % (project, sink_name)
    requested_sink = LogSink(name=sink_path, filter=filter_, destination=destination)
    updated_sink = self._gapic_api.update_sink(
        sink_path, requested_sink,
        unique_writer_identity=unique_writer_identity)
    # NOTE: LogSink message type does not have an ``Any`` field
    # so `MessageToDict`` can safely be used.
    return MessageToDict(updated_sink)
def test_creds_snippet(self):
    """Merged Creds snippets must cover ``self.creds`` and dump consistently."""
    merged = {}
    for component_key, comp in self.components.items():
        # Keys are expected to be exactly '<name>.<suffix>'.
        name, _suffix = component_key.split('.')
        with self.subTest(f'get snippet for {name}'):
            try:
                for transform in comp.transforms:
                    if isinstance(transform, transforms.Creds):
                        merged = {**merged, **transform.snippet()}
            except Exception as e:
                raise Exception(
                    f'Failure to get snippet for {comp}') from e

    for cred_key, expected in self.creds.items():
        self.assertIn(cred_key, merged, f'{cred_key} not found, expecting for {expected}')
        actual = merged[cred_key]
        self.assertEqual(expected.credential_id, actual.credential_id,
                         f'Credential Id mismatch for {expected} and {actual}')
        self.assertEqual(expected.credential_type, actual.credential_type,
                         f'Credential type mismatch for {expected} and {actual}')

    dumped = yaml.load(credentials.dump_credentials(merged), Loader=yaml.SafeLoader)
    expected_entries = sorted(
        (MessageToDict(cred.proto) for cred in merged.values()),
        key=lambda entry: entry['id']['value'])
    self.assertListEqual(dumped['credentials'], expected_entries)
} ], "volumes": [ { "persistentVolumeClaim": { "claimName": "pvc-fedlearner-default" }, "name": "data" } ] } }, "pair": true, "replicas": 1 } } } } ''') ]) return workflow if __name__ == '__main__': print( json.dumps( MessageToDict(make_workflow_template(), preserving_proto_field_name=True, including_default_value_fields=True)))