def _initialize_model(self, config):
    """Build the host-side hetero DNN-LR model from the runtime config.

    Parses logistic and local-model parameters out of *config*, creates
    the local model, and wires both into a HeteroDNNLRHost instance.
    """
    self.logistic_param = ParamExtract.parse_param_from_config(
        LogisticParam(), config)
    local_model_param = ParamExtract.parse_param_from_config(
        LocalModelParam(), config)
    self.local_model = self._create_local_model(local_model_param)
    self.model = HeteroDNNLRHost(self.local_model, self.logistic_param)
    # The encoder output dimension determines the model's data shape.
    self.model.set_data_shape(local_model_param.encode_dim)
def _initialize_model(self, config):
    """Build the guest-side hetero DNN-LR model from the runtime config.

    Parses logistic and local-model parameters, creates the local model,
    and configures the guest model's feature shape, header, and local-model
    update proxy.
    """
    self.logistic_param = ParamExtract.parse_param_from_config(
        LogisticParam(), config)
    local_model_param = ParamExtract.parse_param_from_config(
        LocalModelParam(), config)
    self.local_model = self._create_local_model(local_model_param)
    self.model = HeteroDNNLRGuest(self.local_model, self.logistic_param)
    encode_dim = local_model_param.encode_dim
    self.model.set_feature_shape(encode_dim)
    self.model.set_header(self._create_header(encode_dim))
    self.model.set_local_model_update_proxy(
        SemiEncryptedLocalModelUpdateProxy())
def _initialize_model(self, config):
    """Parse FTL model/local-model/data parameters and delegate setup."""
    LOGGER.debug("@ initialize model")
    ftl_model_param = ParamExtract.parse_param_from_config(
        FTLModelParam(), config)
    ftl_local_model_param = ParamExtract.parse_param_from_config(
        FTLLocalModelParam(), config)
    data_param = FTLDataParam()
    self.ftl_data_param = ParamExtract.parse_param_from_config(
        data_param, config)
    self.ftl_transfer_variable = HeteroFTLTransferVariable()
    self._do_initialize_model(ftl_model_param, ftl_local_model_param,
                              data_param)
def _initialize_model(self, runtime_conf_path):
    """Parse and validate feature-selection parameters, then build the
    host-side selection model.

    :param runtime_conf_path: path to the runtime configuration.
    :raises: whatever FeatureSelectionParamChecker raises on invalid config.
    """
    feature_param = FeatureSelectionParam()
    self.feature_param = ParamExtract.parse_param_from_config(
        feature_param, runtime_conf_path)
    FeatureSelectionParamChecker.check_param(self.feature_param)
    self.model = HeteroFeatureSelectionHost(self.feature_param)
    # Fixed: this method initializes the HOST model; the previous
    # message wrongly said "Guest model started".
    LOGGER.debug("Host model started")
def _initialize_model(self, config):
    """Parse network-embedding parameters and build the arbiter-side model."""
    network_embedding_param = NetworkEmbeddingParam()
    self.network_embedding_param = ParamExtract.parse_param_from_config(
        network_embedding_param, config)
    # Fixed: the old `if self.nrler is None: print("null")` check was dead
    # code — a constructor call can never evaluate to None — and used
    # print() instead of the module logger, so it has been removed.
    self.nrler = HeteroNEArbiter(self.network_embedding_param)
def one_vs_rest_predict(self, data_instance):
    """Run one-vs-rest prediction, intersecting first in HETERO mode.

    :param data_instance: input data table.
    :return: the prediction table, or None when prediction produced nothing.
    """
    if self.mode == consts.HETERO:
        LOGGER.debug("Star intersection before predict")
        data_instance = self.intersect(data_instance, "predict_module_0")
        LOGGER.debug("End intersection before predict")
    # data_instance = self.feature_selection_transform(data_instance)
    # data_instance, fit_config = self.scale(data_instance)
    self.one_vs_rest_param = ParamExtract.parse_param_from_config(
        OneVsRestParam(), self.config_path)
    one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                            self.one_vs_rest_param)
    one_vs_rest.load_model(self.workflow_param.model_table,
                           self.workflow_param.model_namespace)
    predict_result = one_vs_rest.predict(data_instance,
                                         self.workflow_param.predict_param)
    if not predict_result:
        return None
    if predict_result.count() > 10:
        # Log a small sample of the predictions for debugging.
        local_predict = predict_result.collect()
        for _ in range(10):
            result = next(local_predict)
            LOGGER.debug("predict result: {}".format(result))
    return predict_result
def feature_selection_fit(self, data_instance, flow_id='sample_flowid',
                          without_transform=False):
    """Fit (and optionally transform) a hetero feature-selection model.

    :param data_instance: input data table; returned unchanged when selection
        is disabled, the role is ARBITER, or the input is None.
    :param flow_id: flow id set on the selector for transfer-variable scoping.
    :param without_transform: when True, only fit; do not transform the data.
    :return: the (possibly transformed) data table.
    """
    # Fixed: the mode constant and log message were garbled ("consts.H**O",
    # which would parse as exponentiation); homo feature selection is
    # still unsupported, so bail out early.
    if self.mode == consts.HOMO:
        LOGGER.info("Homo feature selection is not supporting yet. Coming soon")
        return data_instance
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_feature_selection:
        LOGGER.info("No need to do feature selection")
        return data_instance

    LOGGER.info("Start feature selection")
    feature_select_param = param_generator.FeatureSelectionParam()
    feature_select_param = ParamExtract.parse_param_from_config(
        feature_select_param, self.config_path)
    param_checker.FeatureSelectionParamChecker.check_param(
        feature_select_param)

    if self.role == consts.HOST:
        feature_selector = HeteroFeatureSelectionHost(feature_select_param)
    elif self.role == consts.GUEST:
        feature_selector = HeteroFeatureSelectionGuest(feature_select_param)
    elif self.role == consts.ARBITER:
        # The arbiter holds no features to select.
        return data_instance
    else:
        raise ValueError("Unknown role of workflow")

    feature_selector.set_flowid(flow_id)

    # IV-based filters need the binning model trained earlier in the pipeline.
    filter_methods = feature_select_param.filter_method
    previous_model = {}
    if 'iv_value_thres' in filter_methods or 'iv_percentile' in filter_methods:
        binning_model = {
            'name': self.workflow_param.model_table,
            'namespace': self.workflow_param.model_namespace
        }
        previous_model['binning_model'] = binning_model
    feature_selector.init_previous_model(**previous_model)

    if without_transform:
        data_instance = feature_selector.fit(data_instance)
    else:
        data_instance = feature_selector.fit_transform(data_instance)

    save_result = feature_selector.save_model(
        self.workflow_param.model_table,
        self.workflow_param.model_namespace)
    # Save model result in pipeline
    for meta_buffer_type, param_buffer_type in save_result:
        self.pipeline.node_meta.append(meta_buffer_type)
        self.pipeline.node_param.append(param_buffer_type)

    LOGGER.info("Finish feature selection")
    return data_instance
def one_hot_encoder_fit_transform(self, data_instance):
    """Fit a one-hot encoder on the data and transform it, when enabled."""
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_one_hot:
        LOGGER.info("No need to do one-hot encode")
        return data_instance

    LOGGER.info("Start one-hot encode")
    one_hot_param = ParamExtract.parse_param_from_config(
        param_generator.OneHotEncoderParam(), self.config_path)
    param_checker.OneHotEncoderParamChecker.check_param(one_hot_param)
    encoder = OneHotEncoder(one_hot_param)
    data_instance = encoder.fit_transform(data_instance)
    save_result = encoder.save_model(self.workflow_param.model_table,
                                     self.workflow_param.model_namespace)
    # Record the saved meta/param buffers in the pipeline.
    for meta_buffer_type, param_buffer_type in save_result:
        self.pipeline.node_meta.append(meta_buffer_type)
        self.pipeline.node_param.append(param_buffer_type)
    LOGGER.info("Finish one-hot encode")
    return data_instance
def intersect(self, data_instance, intersect_flowid=''):
    """Run raw intersection between parties and keep only common rows.

    :param data_instance: input data table.
    :param intersect_flowid: flow id scoping the intersection transfer.
    :return: the intersected table (original header preserved), or the
        input unchanged when intersection is disabled or not applicable.
    """
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_intersect:
        LOGGER.info("need_intersect: false!")
        return data_instance

    header = data_instance.schema.get('header')
    LOGGER.info("need_intersect: true!")
    self.intersect_params = ParamExtract.parse_param_from_config(
        IntersectParam(), self.config_path)
    LOGGER.info("Start intersection!")
    if self.role == consts.HOST:
        intersect_operator = RawIntersectionHost(self.intersect_params)
    elif self.role == consts.GUEST:
        intersect_operator = RawIntersectionGuest(self.intersect_params)
    elif self.role == consts.ARBITER:
        # The arbiter holds no data, so there is nothing to intersect.
        return data_instance
    else:
        raise ValueError("Unknown role of workflow")
    intersect_operator.set_flowid(intersect_flowid)
    intersect_ids = intersect_operator.run(data_instance)
    LOGGER.info("finish intersection!")
    # Keep only the rows whose ids appear in the intersection.
    intersect_data_instance = intersect_ids.join(data_instance,
                                                 lambda i, d: d)
    LOGGER.info("get intersect data_instance!")
    # LOGGER.debug("intersect_data_instance count:{}".format(intersect_data_instance.count()))
    intersect_data_instance.schema['header'] = header
    return intersect_data_instance
def scale(self, data_instance, fit_config=None):
    """Scale features, fitting a new scaler or reusing a saved/passed config.

    :param data_instance: input data table.
    :param fit_config: previously fitted scaling configuration, if any.
    :return: a ``(data_instance, fit_config)`` pair so callers can reuse
        the fitted configuration on validation data.
    """
    if not self.workflow_param.need_scale:
        LOGGER.debug("workflow param need_scale is False")
        return data_instance, fit_config

    self.scale_params = ParamExtract.parse_param_from_config(
        ScaleParam(), self.config_path)
    param_checker.ScaleParamChecker.check_param(self.scale_params)
    scale_obj = Scaler(self.scale_params)

    # In predict mode, try to restore the scaler fitted during training.
    if self.workflow_param.method == "predict":
        fit_config = scale_obj.load_model(
            name=self.workflow_param.model_table,
            namespace=self.workflow_param.model_namespace,
            header=data_instance.schema.get("header"))

    if fit_config:
        data_instance = scale_obj.transform(data_instance, fit_config)
    else:
        # No existing configuration: fit a fresh scaler and persist it.
        data_instance, fit_config = scale_obj.fit(data_instance)
        save_results = scale_obj.save_model(
            name=self.workflow_param.model_table,
            namespace=self.workflow_param.model_namespace)
        if save_results:
            for meta_buffer_type, param_buffer_type in save_results:
                self.pipeline.node_meta.append(meta_buffer_type)
                self.pipeline.node_param.append(param_buffer_type)
    return data_instance, fit_config
def _initialize_model(self, config):
    """Parse boosting-tree parameters and build the host secure-boost model."""
    self.secureboosting_tree_param = ParamExtract.parse_param_from_config(
        BoostingTreeParam(), config)
    self.model = HeteroSecureBoostingTreeHost(self.secureboosting_tree_param)
def _initialize_model(self, runtime_conf_path):
    """Parse and validate feature-binning parameters, then build the
    host-side binning model."""
    self.binning_param = ParamExtract.parse_param_from_config(
        FeatureBinningParam(), runtime_conf_path)
    FeatureBinningParamChecker.check_param(self.binning_param)
    self.model = HeteroFeatureBinningHost(self.binning_param)
    LOGGER.debug("Host part model started")
def train(self, train_data, validation_data=None):
    """Full training workflow: preprocess, fit, save, and evaluate.

    :param train_data: training data table (None for the arbiter).
    :param validation_data: optional data table evaluated after training.
    """
    if self.mode == consts.HETERO and self.role != consts.ARBITER:
        LOGGER.debug("Enter train function")
        LOGGER.debug("Star intersection before train")
        intersect_flowid = "train_0"
        train_data = self.intersect(train_data, intersect_flowid)
        LOGGER.debug("End intersection before train")
        sample_flowid = "train_sample_0"
        train_data = self.sample(train_data, sample_flowid)
        train_data = self.feature_selection_fit(train_data)
        validation_data = self.feature_selection_transform(validation_data)
    if self.mode == consts.HETERO and self.role != consts.ARBITER:
        train_data, cols_scale_value = self.scale(train_data)
        train_data = self.one_hot_encoder_fit_transform(train_data)
        validation_data = self.one_hot_encoder_transform(validation_data)
    if self.workflow_param.one_vs_rest:
        # Wrap the base model in a one-vs-rest meta-model before fitting.
        one_vs_rest_param = OneVsRestParam()
        self.one_vs_rest_param = ParamExtract.parse_param_from_config(
            one_vs_rest_param, self.config_path)
        one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                                self.one_vs_rest_param)
        self.model = one_vs_rest
    self.model.fit(train_data)
    self.save_model()
    LOGGER.debug("finish saving, self role: {}".format(self.role))
    # Fixed: the homo-mode constant was garbled ("consts.H**O", which
    # would parse as exponentiation).
    if self.role == consts.GUEST or self.role == consts.HOST or \
            self.mode == consts.HOMO:
        eval_result = {}
        LOGGER.debug("predicting...")
        predict_result = self.model.predict(
            train_data, self.workflow_param.predict_param)
        LOGGER.debug("evaluating...")
        train_eval = self.evaluate(predict_result)
        eval_result[consts.TRAIN_EVALUATE] = train_eval
        if validation_data is not None:
            self.model.set_flowid("1")
            if self.mode == consts.HETERO:
                LOGGER.debug("Star intersection before predict")
                intersect_flowid = "predict_0"
                validation_data = self.intersect(validation_data,
                                                 intersect_flowid)
                LOGGER.debug("End intersection before predict")
                # Reuse the scaling fitted on the training data.
                validation_data, cols_scale_value = self.scale(
                    validation_data, cols_scale_value)
            val_pred = self.model.predict(
                validation_data, self.workflow_param.predict_param)
            val_eval = self.evaluate(val_pred)
            eval_result[consts.VALIDATE_EVALUATE] = val_eval
        LOGGER.info("{} eval_result: {}".format(self.role, eval_result))
        self.save_eval_result(eval_result)
def feature_selection_fit(self, data_instance, flow_id='sample_flowid'):
    """Fit and transform a feature-selection model, locally or federated.

    Fixes: the homo-mode constant and its log message were garbled
    ("consts.H**O"), and the identical save/pipeline bookkeeping that was
    duplicated in both the local and federated branches is unified.

    :param data_instance: input data table; returned unchanged when
        selection is disabled, the role is ARBITER, or the input is None.
    :param flow_id: flow id set on the selector for transfer-variable scoping.
    :return: the transformed data table.
    """
    if self.mode == consts.HOMO:
        LOGGER.info("Homo feature selection is not supporting yet. Coming soon")
        return data_instance
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_feature_selection:
        LOGGER.info("No need to do feature selection")
        return data_instance

    LOGGER.info("Start feature selection")
    feature_select_param = param_generator.FeatureSelectionParam()
    feature_select_param = ParamExtract.parse_param_from_config(
        feature_select_param, self.config_path)
    param_checker.FeatureSelectionParamChecker.check_param(
        feature_select_param)

    if self.role == consts.HOST:
        feature_selector = HeteroFeatureSelectionHost(feature_select_param)
    elif self.role == consts.GUEST:
        feature_selector = HeteroFeatureSelectionGuest(feature_select_param)
    elif self.role == consts.ARBITER:
        return data_instance
    else:
        raise ValueError("Unknown role of workflow")

    feature_selector.set_flowid(flow_id)

    # Decide whether do fit_local or fit
    if feature_select_param.local_only:
        data_instance = feature_selector.fit_local_transform(data_instance)
    else:
        data_instance = feature_selector.fit_transform(data_instance)

    save_result = feature_selector.save_model(
        self.workflow_param.model_table,
        self.workflow_param.model_namespace)
    # Save model result in pipeline
    for meta_buffer_type, param_buffer_type in save_result:
        self.pipeline.node_meta.append(meta_buffer_type)
        self.pipeline.node_param.append(param_buffer_type)

    LOGGER.info("Finish feature selection")
    return data_instance
def get_filled_param(param_var, config_json):
    """Recursively fill *param_var*'s fields from *config_json*.

    Only classes declared in ``federatedml.param.param`` are accepted as
    nested parameter objects during the recursive parse.
    """
    from federatedml.param import param
    valid_classes = [name for name, _ in
                     inspect.getmembers(param, inspect.isclass)]
    return ParamExtract.recursive_parse_param_from_config(
        param_var, config_json, valid_classes, param_parse_depth=0)
def one_vs_rest_train(self, train_data, validation_data=None):
    """Train a one-vs-rest wrapper around the current model, predict on
    the validation data, and persist the resulting sub-models."""
    self.one_vs_rest_param = ParamExtract.parse_param_from_config(
        OneVsRestParam(), self.config_path)
    one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                            self.one_vs_rest_param)
    LOGGER.debug("Start OneVsRest train")
    one_vs_rest.fit(train_data)
    LOGGER.debug("Start OneVsRest predict")
    one_vs_rest.predict(validation_data, self.workflow_param.predict_param)
    save_result = one_vs_rest.save_model(
        self.workflow_param.model_table,
        self.workflow_param.model_namespace)
    if save_result is None:
        return
    # Record the saved meta/param buffers in the pipeline.
    for meta_buffer_type, param_buffer_type in save_result:
        self.pipeline.node_meta.append(meta_buffer_type)
        self.pipeline.node_param.append(param_buffer_type)
def feature_selection_transform(self, data_instance,
                                flow_id='sample_flowid'):
    """Transform data with a previously-saved feature-selection model.

    Fixed: the homo-mode constant and its log message were garbled
    ("consts.H**O", which would parse as exponentiation).

    :param data_instance: input data table; returned unchanged when
        selection is disabled, the role is ARBITER, or the input is None.
    :param flow_id: flow id set on the selector for transfer-variable scoping.
    :return: the transformed data table.
    """
    if self.mode == consts.HOMO:
        LOGGER.info("Homo feature selection is not supporting yet. Coming soon")
        return data_instance
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_feature_selection:
        LOGGER.info("No need to do feature selection")
        return data_instance

    LOGGER.info("Start feature selection transform")
    feature_select_param = param_generator.FeatureSelectionParam()
    feature_select_param = ParamExtract.parse_param_from_config(
        feature_select_param, self.config_path)
    param_checker.FeatureSelectionParamChecker.check_param(
        feature_select_param)

    if self.role == consts.HOST:
        feature_selector = HeteroFeatureSelectionHost(feature_select_param)
    elif self.role == consts.GUEST:
        feature_selector = HeteroFeatureSelectionGuest(feature_select_param)
    elif self.role == consts.ARBITER:
        return data_instance
    else:
        raise ValueError("Unknown role of workflow")

    feature_selector.set_flowid(flow_id)
    feature_selector.load_model(self.workflow_param.model_table,
                                self.workflow_param.model_namespace)
    LOGGER.debug(
        "Role: {}, in transform feature selector left_cols: {}".format(
            self.role, feature_selector.left_cols))
    data_instance = feature_selector.transform(data_instance)
    LOGGER.info("Finish feature selection")
    return data_instance
def sample(self, data_instance, sample_flowid="sample_flowid"):
    """Run the configured sampling strategy over the data, when enabled.

    The arbiter never holds data, so it skips sampling entirely.
    """
    if not self.workflow_param.need_sample:
        LOGGER.info("need_sample: false!")
        return data_instance
    if self.role == consts.ARBITER:
        LOGGER.info("arbiter not need sample")
        return data_instance
    LOGGER.info("need_sample: true!")
    sample_param = ParamExtract.parse_param_from_config(
        SampleParam(), self.config_path)
    sampler = Sampler(sample_param)
    sampler.set_flowid(sample_flowid)
    data_instance = sampler.run(data_instance, self.mode, self.role)
    LOGGER.info("sample result size is {}".format(data_instance.count()))
    return data_instance
def feature_binning(self, data_instances, flow_id='sample_flowid'):
    """Fit a hetero feature-binning model (locally or federated).

    Fixes: the homo-mode constant/message were garbled ("consts.H**O"),
    the early-exit message wrongly said "feature selection", and the final
    log message also said "Finish feature selection" in a binning routine.

    :param data_instances: input data table.
    :param flow_id: flow id set on the binning object.
    :return: the (possibly fitted/transformed) data table.
    """
    if self.mode == consts.HOMO:
        LOGGER.info("Homo feature binning is not supporting yet. Coming soon")
        return data_instances
    if data_instances is None:
        return data_instances

    LOGGER.info("Start feature binning")
    feature_binning_param = param_generator.FeatureBinningParam()
    feature_binning_param = ParamExtract.parse_param_from_config(
        feature_binning_param, self.config_path)
    param_checker.FeatureBinningParamChecker.check_param(
        feature_binning_param)

    if self.role == consts.HOST:
        feature_binning_obj = HeteroFeatureBinningHost(feature_binning_param)
    elif self.role == consts.GUEST:
        feature_binning_obj = HeteroFeatureBinningGuest(feature_binning_param)
    elif self.role == consts.ARBITER:
        # The arbiter holds no features to bin.
        return data_instances
    else:
        raise ValueError("Unknown role of workflow")

    feature_binning_obj.set_flowid(flow_id)
    if feature_binning_param.local_only:
        data_instances = feature_binning_obj.fit_local(data_instances)
    else:
        data_instances = feature_binning_obj.fit(data_instances)

    save_result = feature_binning_obj.save_model(
        self.workflow_param.model_table,
        self.workflow_param.model_namespace)
    # Save model result in pipeline
    for meta_buffer_type, param_buffer_type in save_result:
        self.pipeline.node_meta.append(meta_buffer_type)
        self.pipeline.node_param.append(param_buffer_type)

    LOGGER.info("Finish feature binning")
    return data_instances
def one_hot_encoder_transform(self, data_instance):
    """Apply a previously-saved one-hot encoder to the data, when enabled."""
    if data_instance is None:
        return data_instance
    if not self.workflow_param.need_one_hot:
        LOGGER.info("No need to do one-hot encode")
        return data_instance

    LOGGER.info("Start one-hot encode")
    one_hot_param = ParamExtract.parse_param_from_config(
        param_generator.OneHotEncoderParam(), self.config_path)
    param_checker.OneHotEncoderParamChecker.check_param(one_hot_param)
    encoder = OneHotEncoder(one_hot_param)
    encoder.load_model(self.workflow_param.model_table,
                       self.workflow_param.model_namespace)
    data_instance = encoder.transform(data_instance)
    LOGGER.info("Finish one-hot encode")
    return data_instance
def _initialize_model(self, config):
    """Parse logistic-regression parameters and build the hetero-LR host model."""
    self.logistic_param = ParamExtract.parse_param_from_config(
        LogisticParam(), config)
    self.model = HeteroLRHost(self.logistic_param)
def _initialize_model(self, runtime_conf_path):
    """Parse and validate feature-selection parameters from the runtime config."""
    self.feature_param = ParamExtract.parse_param_from_config(
        FeatureSelectionParam(), runtime_conf_path)
    FeatureSelectionParamChecker.check_param(self.feature_param)
def _initialize_model(self, runtime_conf_path):
    """Parse logistic-regression parameters and build the homo-LR arbiter model."""
    self.logistic_param = ParamExtract.parse_param_from_config(
        LogisticParam(), runtime_conf_path)
    self.model = HomoLRArbiter(self.logistic_param)
def _initialize_workflow_param(self, config_path):
    """Parse and validate the workflow-level parameters."""
    self.workflow_param = ParamExtract.parse_param_from_config(
        WorkFlowParam(), config_path)
    WorkFlowParamChecker.check_param(self.workflow_param)
def _load_param(self, param):
    """Recursively fill *param* from the cached config json.

    Only classes listed in ``self.valid_classes`` may appear as nested
    parameter objects during the recursive parse.
    """
    return ParamExtract.recursive_parse_param_from_config(
        param, self.config_json, self.valid_classes, param_parse_depth=0)
def run(self):
    """Workflow entry point: parse arguments, then dispatch on
    ``workflow_param.method`` (train / predict / intersect /
    cross_validation / one_vs_rest_train)."""
    self._init_argument()
    if self.workflow_param.method == "train":
        # create a new pipeline
        LOGGER.debug("In running function, enter train method")
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            # Only data-holding roles load input tables.
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if self.workflow_param.predict_input_table is not None and self.workflow_param.predict_input_namespace is not None:
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace))
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace,
                    mode='transform')
        self.train(train_data_instance,
                   validation_data=predict_data_instance)
        self._save_pipeline()
    elif self.workflow_param.method == "predict":
        data_instance = self.gen_data_instance(
            self.workflow_param.predict_input_table,
            self.workflow_param.predict_input_namespace,
            mode='transform')
        if self.workflow_param.one_vs_rest:
            # Re-wrap the model so one-vs-rest sub-models can be loaded.
            one_vs_rest_param = OneVsRestParam()
            self.one_vs_rest_param = ParamExtract.parse_param_from_config(
                one_vs_rest_param, self.config_path)
            one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                                    self.one_vs_rest_param)
            self.model = one_vs_rest
        self.load_model()
        self.predict(data_instance)
    elif self.workflow_param.method == "intersect":
        LOGGER.debug(
            "[Intersect]Input table:{}, input namesapce: {}".format(
                self.workflow_param.data_input_table,
                self.workflow_param.data_input_namespace))
        data_instance = self.gen_data_instance(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace)
        self.intersect(data_instance)
    elif self.workflow_param.method == "cross_validation":
        data_instance = None
        if self.role != consts.ARBITER:
            data_instance = self.gen_data_instance(
                self.workflow_param.data_input_table,
                self.workflow_param.data_input_namespace)
        self.cross_validation(data_instance)
    elif self.workflow_param.method == "one_vs_rest_train":
        LOGGER.debug("In running function, enter one_vs_rest method")
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if self.workflow_param.predict_input_table is not None and self.workflow_param.predict_input_namespace is not None:
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace))
                # NOTE(review): unlike the "train" branch, no
                # mode='transform' here — confirm this is intentional.
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace)
        self.one_vs_rest_train(train_data_instance,
                               validation_data=predict_data_instance)
        # self.one_vs_rest_predict(predict_data_instance)
        self._save_pipeline()
    else:
        raise TypeError("method %s is not support yet" %
                        (self.workflow_param.method))
def run(self):
    """Workflow entry point: parse arguments, then dispatch on
    ``workflow_param.method``; this variant additionally supports a
    'neighbors_sampling' method for network-embedding workflows."""
    self._init_argument()
    if self.workflow_param.method == "train":
        # create a new pipeline
        LOGGER.debug("In running function, enter train method")
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            # Only data-holding roles load input tables.
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace
            ))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if self.workflow_param.predict_input_table is not None and self.workflow_param.predict_input_namespace is not None:
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace
                ))
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace,
                    mode='transform')
        self.train(train_data_instance,
                   validation_data=predict_data_instance)
        self._save_pipeline()
    elif self.workflow_param.method == 'neighbors_sampling':
        LOGGER.debug("In running function, enter neighbors sampling")
        LOGGER.debug("[Neighbors sampling]Input table:{}, input namespace:{}".format(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace
        ))
        data_instance = self.gen_data_instance(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace)
        LOGGER.info("{}".format(self.workflow_param.local_samples_namespace))
        LOGGER.info("{}".format(self.workflow_param.distributed_samples_namespace))
        adj_instances = data_instance
        # Intersect to find the nodes common to both parties.
        intersect_flowid = 'neigh_sam_intersect_0'
        common_instance = self.intersect(data_instance, intersect_flowid)
        LOGGER.info("The number of common nodes: {}".format(common_instance.count()))
        # Sample neighborhoods from the local adjacency data.
        local_instances = self.neighbors_sampler.local_neighbors_sampling(
            adj_instances, self.role)
        # persistent
        local_instances.save_as(
            name=self.role,
            namespace=self.workflow_param.local_samples_namespace,
            partition=10)
        # Bridge nodes connect the two parties' graphs; intersect them
        # again under a second flow id before distributed sampling.
        bridge_instances = NeighborsSampling.get_bridge_nodes(common_instance)
        intersect_flowid_2 = 'neigh_sam_intersect_1'
        bridge_instances = self.intersect(bridge_instances, intersect_flowid_2)
        logDtableInstances(LOGGER, bridge_instances, 5)
        distributed_instances_target, distributed_instances_anchor = self.neighbors_sampler.distributed_neighbors_sampling(
            bridge_instances, adj_instances)
        distributed_instances_target.save_as(
            name="target",
            namespace=self.workflow_param.distributed_samples_namespace + "/" + self.role,
            partition=10)
        distributed_instances_anchor.save_as(
            name='anchor',
            namespace=self.workflow_param.distributed_samples_namespace + "/" + self.role,
            partition=10)
        if self.role == 'host':
            LOGGER.info("Neighbors_sampling_finish")
    elif self.workflow_param.method == "predict":
        data_instance = self.gen_data_instance(
            self.workflow_param.predict_input_table,
            self.workflow_param.predict_input_namespace,
            mode='transform')
        if self.workflow_param.one_vs_rest:
            # Re-wrap the model so one-vs-rest sub-models can be loaded.
            one_vs_rest_param = OneVsRestParam()
            self.one_vs_rest_param = ParamExtract.parse_param_from_config(
                one_vs_rest_param, self.config_path)
            one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                                    self.one_vs_rest_param)
            self.model = one_vs_rest
        self.load_model()
        self.predict(data_instance)
    elif self.workflow_param.method == "intersect":
        LOGGER.debug("[Intersect]Input table:{}, input namespace: {}".format(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace
        ))
        data_instance = self.gen_data_instance(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace)
        self.intersect(data_instance)
    elif self.workflow_param.method == "cross_validation":
        data_instance = None
        if self.role != consts.ARBITER:
            data_instance = self.gen_data_instance(
                self.workflow_param.data_input_table,
                self.workflow_param.data_input_namespace)
        self.cross_validation(data_instance)
    elif self.workflow_param.method == "one_vs_rest_train":
        LOGGER.debug("In running function, enter one_vs_rest method")
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace
            ))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if self.workflow_param.predict_input_table is not None and self.workflow_param.predict_input_namespace is not None:
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace
                ))
                # NOTE(review): unlike the "train" branch, no
                # mode='transform' here — confirm this is intentional.
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace)
        self.one_vs_rest_train(train_data_instance,
                               validation_data=predict_data_instance)
        # self.one_vs_rest_predict(predict_data_instance)
        self._save_pipeline()
    else:
        raise TypeError("method %s is not support yet" %
                        (self.workflow_param.method))
def _initialize_model(self, config):
    """Parse network-embedding parameters and build the host-side model."""
    self.network_embedding_param = ParamExtract.parse_param_from_config(
        NetworkEmbeddingParam(), config)
    self.nrler = HeteroNEHost(self.network_embedding_param)