def export_model(self, param):
    """Package the current model state into protobuf messages.

    The param message receives the API version, the trainer checkpoint as
    raw bytes, the header and (when present) the label mapping. The meta
    message receives the protobuf form of ``param`` and the current
    aggregation iteration.

    Args:
        param: object exposing ``api_version`` and ``generate_pb()``.

    Returns:
        dict keyed by ``_consts.MODEL_META_NAME`` and
        ``_consts.MODEL_PARAM_NAME`` holding the meta and param messages.
    """
    model_param = nn_model_param_pb2.NNModelParam()
    model_param.api_version = param.api_version

    # Round-trip the checkpoint through a throwaway directory so it can be
    # captured as bytes; the directory is removed when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "model.ckpt")
        self.pl_trainer.save_checkpoint(ckpt_path)
        with open(ckpt_path, "rb") as ckpt_file:
            model_param.saved_model_bytes = ckpt_file.read()

    model_param.header.extend(self.header)

    # Labels and their mapped values are stored JSON-encoded.
    label_mapping = self.label_mapping
    if label_mapping is not None:
        for raw_label, mapped_label in label_mapping.items():
            model_param.label_mapping.add(
                label=json.dumps(raw_label),
                mapped=json.dumps(mapped_label),
            )

    model_meta = nn_model_meta_pb2.NNModelMeta()
    model_meta.params.CopyFrom(param.generate_pb())
    model_meta.aggregate_iter = self.context.aggregation_iteration

    exported = {}
    exported[_consts.MODEL_META_NAME] = model_meta
    exported[_consts.MODEL_PARAM_NAME] = model_param
    return exported
def _get_param(self):
    """Build the ``NNModelParam`` protobuf message for the current model.

    The message carries the bytes returned by ``self.nn_model.export_model()``,
    the entries of ``self._header`` and, when a label-alignment mapping is
    set, one ``label_mapping`` entry per pair with both sides JSON-encoded.

    Returns:
        The populated ``NNModelParam`` message.
    """
    from federatedml.protobuf.generated import nn_model_param_pb2

    param_pb = nn_model_param_pb2.NNModelParam()
    param_pb.saved_model_bytes = self.nn_model.export_model()
    param_pb.header.extend(self._header)

    # The mapping may be unset (None); in that case export an empty
    # label_mapping instead of failing with AttributeError on `.items()`.
    label_align_mapping = self._label_align_mapping
    if label_align_mapping is not None:
        for label, mapped in label_align_mapping.items():
            param_pb.label_mapping.add(
                label=json.dumps(label),
                mapped=json.dumps(mapped),
            )
    return param_pb
def export_model(self, param):
    """Package the API version and training meta into protobuf messages.

    Args:
        param: object exposing ``api_version`` and ``generate_pb()``.

    Returns:
        dict keyed by ``_consts.MODEL_META_NAME`` and
        ``_consts.MODEL_PARAM_NAME`` holding the meta and param messages.
    """
    # Param message: only the API version is recorded here.
    model_param = nn_model_param_pb2.NNModelParam()
    model_param.api_version = param.api_version

    # Meta message: protobuf form of the params plus the current
    # aggregation iteration.
    model_meta = nn_model_meta_pb2.NNModelMeta()
    model_meta.params.CopyFrom(param.generate_pb())
    model_meta.aggregate_iter = self.context.aggregation_iteration

    exported = {}
    exported[_consts.MODEL_META_NAME] = model_meta
    exported[_consts.MODEL_PARAM_NAME] = model_param
    return exported
def _get_param(self):
    """Return an ``NNModelParam`` message holding the exported model bytes.

    Only ``saved_model_bytes`` is populated, from
    ``self.nn_model.export_model()``.
    """
    from federatedml.protobuf.generated import nn_model_param_pb2

    model_param = nn_model_param_pb2.NNModelParam()
    exported_bytes = self.nn_model.export_model()
    model_param.saved_model_bytes = exported_bytes
    return model_param
def arbiter_get_param(self):
    """Return a freshly constructed ``NNModelParam`` message with no fields set."""
    from federatedml.protobuf.generated import nn_model_param_pb2

    return nn_model_param_pb2.NNModelParam()