コード例 #1
0
ファイル: _torch.py プロジェクト: yubo1993/FATE
    def export_model(self, param):

        param_pb = nn_model_param_pb2.NNModelParam()

        # save api_version
        param_pb.api_version = param.api_version

        # save pl model bytes
        with tempfile.TemporaryDirectory() as d:
            filepath = os.path.join(d, "model.ckpt")
            self.pl_trainer.save_checkpoint(filepath)
            with open(filepath, "rb") as f:
                param_pb.saved_model_bytes = f.read()

        # save header
        param_pb.header.extend(self.header)

        # save label mapping
        if self.label_mapping is not None:
            for label, mapped in self.label_mapping.items():
                param_pb.label_mapping.add(label=json.dumps(label),
                                           mapped=json.dumps(mapped))

        # meta
        meta_pb = nn_model_meta_pb2.NNModelMeta()
        meta_pb.params.CopyFrom(param.generate_pb())
        meta_pb.aggregate_iter = self.context.aggregation_iteration

        return {
            _consts.MODEL_META_NAME: meta_pb,
            _consts.MODEL_PARAM_NAME: param_pb
        }
コード例 #2
0
ファイル: enter_point.py プロジェクト: zpskt/FATE
 def _get_param(self):
     from federatedml.protobuf.generated import nn_model_param_pb2
     param_pb = nn_model_param_pb2.NNModelParam()
     param_pb.saved_model_bytes = self.nn_model.export_model()
     param_pb.header.extend(self._header)
     for label, mapped in self._label_align_mapping.items():
         param_pb.label_mapping.add(label=json.dumps(label), mapped=json.dumps(mapped))
     return param_pb
コード例 #3
0
ファイル: _torch.py プロジェクト: yubo1993/FATE
    def export_model(self, param):

        param_pb = nn_model_param_pb2.NNModelParam()

        # save api_version
        param_pb.api_version = param.api_version

        meta_pb = nn_model_meta_pb2.NNModelMeta()
        meta_pb.params.CopyFrom(param.generate_pb())
        meta_pb.aggregate_iter = self.context.aggregation_iteration

        return {
            _consts.MODEL_META_NAME: meta_pb,
            _consts.MODEL_PARAM_NAME: param_pb
        }
コード例 #4
0
ファイル: enter_point.py プロジェクト: zhxuan300/FATE
 def _get_param(self):
     from federatedml.protobuf.generated import nn_model_param_pb2
     param_pb = nn_model_param_pb2.NNModelParam()
     param_pb.saved_model_bytes = self.nn_model.export_model()
     return param_pb
コード例 #5
0
ファイル: _version_0.py プロジェクト: yubo1993/FATE
def arbiter_get_param(self):
    from federatedml.protobuf.generated import nn_model_param_pb2

    param_pb = nn_model_param_pb2.NNModelParam()
    return param_pb