def one_vs_rest_predict(self, data_instance):
    """Predict ``data_instance`` with a OneVsRest model loaded from the model table.

    In HETERO mode the input is first intersected with the other parties
    (flow id ``"predict_module_0"``). The OneVsRest parameters are loaded via
    ``self._load_param`` and the model from
    ``workflow_param.model_table`` / ``workflow_param.model_namespace``.

    Args:
        data_instance: the instances to predict on.

    Returns:
        The object returned by ``OneVsRest.predict``, or ``None`` when that
        result is falsy.
    """
    import itertools

    if self.mode == consts.HETERO:
        LOGGER.debug("Star intersection before predict")
        intersect_flowid = "predict_module_0"
        data_instance = self.intersect(data_instance, intersect_flowid)
        LOGGER.debug("End intersection before predict")

    # NOTE(review): feature selection and scaling are NOT applied before
    # prediction here (those calls used to be commented out); confirm this
    # matches how the model was trained.

    self.one_vs_rest_param = self._load_param(OneVsRestParam())
    one_vs_rest = OneVsRest(self.model, self.role, self.mode, self.one_vs_rest_param)
    one_vs_rest.load_model(self.workflow_param.model_table,
                           self.workflow_param.model_namespace)

    predict_result = one_vs_rest.predict(data_instance, self.workflow_param.predict_param)
    if not predict_result:
        return None

    # Log up to the first 10 predictions for debugging. islice stops cleanly on
    # shorter results, so no separate count() pass is needed to guard next().
    for result in itertools.islice(predict_result.collect(), 10):
        LOGGER.debug("predict result: {}".format(result))

    return predict_result
def train(self, train_data, validation_data=None):
    """Preprocess the data, fit ``self.model``, save it and evaluate it.

    Steps, in order:

    1. HETERO, non-arbiter: intersect ``train_data`` (flow id ``"train_0"``).
    2. Sample ``train_data`` (flow id ``"train_sample_0"``).
    3. Fit feature selection on the train data and apply it to the validation data.
    4. HETERO, non-arbiter: scale the train data, keeping the returned scale values.
    5. Fit one-hot encoding on the train data and apply it to the validation data.
    6. If ``workflow_param.one_vs_rest`` is set, wrap ``self.model`` in ``OneVsRest``.
    7. Fit and save the model.
    8. Guest/host parties (any mode) or every party (HOMO mode) predict on the
       train data — and on ``validation_data`` when given — evaluate the
       predictions and save the evaluation results.

    Args:
        train_data: training instances; ``run`` passes ``None`` for the arbiter.
        validation_data: optional validation instances; ``None`` skips validation.
    """
    is_hetero_party = self.mode == consts.HETERO and self.role != consts.ARBITER
    # Produced when scaling the train data on the HETERO non-arbiter path and
    # passed back to scale() for the validation data. Initialised so the name
    # always exists regardless of which path is taken.
    cols_scale_value = None

    if is_hetero_party:
        LOGGER.debug("Enter train function")
        LOGGER.debug("Star intersection before train")
        intersect_flowid = "train_0"
        train_data = self.intersect(train_data, intersect_flowid)
        LOGGER.debug("End intersection before train")

    sample_flowid = "train_sample_0"
    train_data = self.sample(train_data, sample_flowid)

    train_data = self.feature_selection_fit(train_data)
    validation_data = self.feature_selection_transform(validation_data)

    if is_hetero_party:
        train_data, cols_scale_value = self.scale(train_data)

    train_data = self.one_hot_encoder_fit_transform(train_data)
    validation_data = self.one_hot_encoder_transform(validation_data)

    if self.workflow_param.one_vs_rest:
        self.one_vs_rest_param = ParamExtract.parse_param_from_config(
            OneVsRestParam(), self.config_path)
        self.model = OneVsRest(self.model, self.role, self.mode, self.one_vs_rest_param)

    self.model.fit(train_data)
    self.save_model()
    LOGGER.debug("finish saving, self role: {}".format(self.role))

    # Only guest and host evaluate; in HOMO mode every party does.
    if self.role not in (consts.GUEST, consts.HOST) and self.mode != consts.HOMO:
        return

    eval_result = {}
    LOGGER.debug("predicting...")
    predict_result = self.model.predict(train_data, self.workflow_param.predict_param)

    LOGGER.debug("evaluating...")
    eval_result[consts.TRAIN_EVALUATE] = self.evaluate(predict_result)

    if validation_data is not None:
        # Distinct flow id, presumably to keep the validation prediction's
        # messages apart from the train prediction above.
        self.model.set_flowid("1")
        if self.mode == consts.HETERO:
            LOGGER.debug("Star intersection before predict")
            intersect_flowid = "predict_0"
            validation_data = self.intersect(validation_data, intersect_flowid)
            LOGGER.debug("End intersection before predict")
            validation_data, cols_scale_value = self.scale(validation_data, cols_scale_value)

        val_pred = self.model.predict(validation_data, self.workflow_param.predict_param)
        eval_result[consts.VALIDATE_EVALUATE] = self.evaluate(val_pred)

    LOGGER.info("{} eval_result: {}".format(self.role, eval_result))
    self.save_eval_result(eval_result)
def one_vs_rest_train(self, train_data, validation_data=None):
    """Train a OneVsRest wrapper around ``self.model`` and record it in the pipeline.

    The OneVsRest parameters are parsed via ``ParamExtract`` from
    ``self.config_path``. After fitting on ``train_data`` the wrapper is run
    on ``validation_data`` (the prediction itself is discarded), then saved
    to ``workflow_param.model_table`` / ``model_namespace``. Every
    ``(meta, param)`` buffer pair returned by the save is appended to
    ``self.pipeline``.

    Args:
        train_data: training instances.
        validation_data: instances passed to ``OneVsRest.predict``; may be None.
    """
    self.one_vs_rest_param = ParamExtract.parse_param_from_config(
        OneVsRestParam(), self.config_path)
    ovr_model = OneVsRest(self.model, self.role, self.mode, self.one_vs_rest_param)

    LOGGER.debug("Start OneVsRest train")
    ovr_model.fit(train_data)

    LOGGER.debug("Start OneVsRest predict")
    ovr_model.predict(validation_data, self.workflow_param.predict_param)

    saved_buffers = ovr_model.save_model(self.workflow_param.model_table,
                                         self.workflow_param.model_namespace)
    # save_model may return None, in which case there is nothing to record.
    if saved_buffers is not None:
        for meta_buffer, param_buffer in saved_buffers:
            self.pipeline.node_meta.append(meta_buffer)
            self.pipeline.node_param.append(param_buffer)
def run(self, config_json, job_id):
    """Initialise the workflow from ``config_json``/``job_id`` and execute it.

    Dispatches on ``self.workflow_param.method``:

    * ``"train"``: load train (and optional validation) data, train, save the pipeline.
    * ``"predict"``: load data in transform mode, optionally wrap the model in
      ``OneVsRest``, load the model and predict.
    * ``"intersect"``: intersect the data-input table.
    * ``"cross_validation"``: run cross validation (the arbiter gets no data).
    * ``"one_vs_rest_train"``: like ``"train"`` but via ``one_vs_rest_train``.

    Raises:
        TypeError: if the method is none of the above (type kept for
            compatibility with existing callers).
    """
    self._init_argument(config_json, job_id)
    method = self.workflow_param.method

    def load_train_and_validation():
        # Read the training table and, when configured, the validation table.
        # The arbiter reads nothing and gets (None, None).
        # Validation data is always read with mode='transform' so the "train"
        # and "one_vs_rest_train" paths treat it identically; presumably this
        # applies the data-io transform fitted on the training table.
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if (self.workflow_param.predict_input_table is not None
                    and self.workflow_param.predict_input_namespace is not None):
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace))
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace,
                    mode='transform')
        return train_data_instance, predict_data_instance

    if method == "train":
        LOGGER.debug("In running function, enter train method")
        train_data_instance, predict_data_instance = load_train_and_validation()
        self.train(train_data_instance, validation_data=predict_data_instance)
        self._save_pipeline()
    elif method == "predict":
        data_instance = self.gen_data_instance(
            self.workflow_param.predict_input_table,
            self.workflow_param.predict_input_namespace,
            mode='transform')
        if self.workflow_param.one_vs_rest:
            self.one_vs_rest_param = self._load_param(OneVsRestParam())
            self.model = OneVsRest(self.model, self.role, self.mode, self.one_vs_rest_param)
        self.load_model()
        self.predict(data_instance)
    elif method == "intersect":
        LOGGER.debug(
            "[Intersect]Input table:{}, input namesapce: {}".format(
                self.workflow_param.data_input_table,
                self.workflow_param.data_input_namespace))
        data_instance = self.gen_data_instance(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace)
        self.intersect(data_instance)
    elif method == "cross_validation":
        data_instance = None
        if self.role != consts.ARBITER:
            data_instance = self.gen_data_instance(
                self.workflow_param.data_input_table,
                self.workflow_param.data_input_namespace)
        self.cross_validation(data_instance)
    elif method == "one_vs_rest_train":
        LOGGER.debug("In running function, enter one_vs_rest method")
        train_data_instance, predict_data_instance = load_train_and_validation()
        self.one_vs_rest_train(train_data_instance, validation_data=predict_data_instance)
        self._save_pipeline()
    else:
        raise TypeError("method %s is not support yet" % (method))

    LOGGER.debug("run_DONE")
def run(self):
    """Initialise the workflow and execute the configured method.

    Dispatches on ``self.workflow_param.method``:

    * ``"train"``: load train (and optional validation) data, train, save the pipeline.
    * ``"neighbors_sampling"``: intersect the input data with the other
      parties, run local and distributed neighbor sampling and save the
      samples (``partition=10``) under the configured namespaces.
    * ``"predict"``: load data in transform mode, optionally wrap the model in
      ``OneVsRest``, load the model and predict.
    * ``"intersect"``: intersect the data-input table.
    * ``"cross_validation"``: run cross validation (the arbiter gets no data).
    * ``"one_vs_rest_train"``: like ``"train"`` but via ``one_vs_rest_train``.

    Raises:
        TypeError: if the method is none of the above (type kept for
            compatibility with existing callers).
    """
    self._init_argument()
    method = self.workflow_param.method

    def load_train_and_validation():
        # Read the training table and, when configured, the validation table.
        # The arbiter reads nothing and gets (None, None).
        # Validation data is always read with mode='transform' so the "train"
        # and "one_vs_rest_train" paths treat it identically; presumably this
        # applies the data-io transform fitted on the training table.
        train_data_instance = None
        predict_data_instance = None
        if self.role != consts.ARBITER:
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace))
            train_data_instance = self.gen_data_instance(
                self.workflow_param.train_input_table,
                self.workflow_param.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            if (self.workflow_param.predict_input_table is not None
                    and self.workflow_param.predict_input_namespace is not None):
                LOGGER.debug("Input table:{}, input namesapce: {}".format(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace))
                predict_data_instance = self.gen_data_instance(
                    self.workflow_param.predict_input_table,
                    self.workflow_param.predict_input_namespace,
                    mode='transform')
        return train_data_instance, predict_data_instance

    if method == "train":
        LOGGER.debug("In running function, enter train method")
        train_data_instance, predict_data_instance = load_train_and_validation()
        self.train(train_data_instance, validation_data=predict_data_instance)
        self._save_pipeline()
    elif method == 'neighbors_sampling':
        LOGGER.debug("In running function, enter neighbors sampling")
        LOGGER.debug("[Neighbors sampling]Input table:{}, input namespace:{}".format(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace))
        data_instance = self.gen_data_instance(self.workflow_param.data_input_table,
                                               self.workflow_param.data_input_namespace)
        LOGGER.info("{}".format(self.workflow_param.local_samples_namespace))
        LOGGER.info("{}".format(self.workflow_param.distributed_samples_namespace))

        # The raw input instances feed both the local and distributed sampling.
        adj_instances = data_instance
        common_instance = self.intersect(data_instance, 'neigh_sam_intersect_0')
        LOGGER.info("The number of common nodes: {}".format(common_instance.count()))

        # Local sampling, saved under this party's role name.
        local_instances = self.neighbors_sampler.local_neighbors_sampling(adj_instances, self.role)
        local_instances.save_as(name=self.role,
                                namespace=self.workflow_param.local_samples_namespace,
                                partition=10)

        # Distributed sampling over the bridge nodes shared by the parties.
        bridge_instances = NeighborsSampling.get_bridge_nodes(common_instance)
        bridge_instances = self.intersect(bridge_instances, 'neigh_sam_intersect_1')
        logDtableInstances(LOGGER, bridge_instances, 5)
        distributed_instances_target, distributed_instances_anchor = \
            self.neighbors_sampler.distributed_neighbors_sampling(bridge_instances, adj_instances)

        distributed_namespace = self.workflow_param.distributed_samples_namespace + "/" + self.role
        distributed_instances_target.save_as(name="target",
                                             namespace=distributed_namespace,
                                             partition=10)
        distributed_instances_anchor.save_as(name='anchor',
                                             namespace=distributed_namespace,
                                             partition=10)
        if self.role == 'host':
            LOGGER.info("Neighbors_sampling_finish")
    elif method == "predict":
        data_instance = self.gen_data_instance(self.workflow_param.predict_input_table,
                                               self.workflow_param.predict_input_namespace,
                                               mode='transform')
        if self.workflow_param.one_vs_rest:
            self.one_vs_rest_param = ParamExtract.parse_param_from_config(
                OneVsRestParam(), self.config_path)
            self.model = OneVsRest(self.model, self.role, self.mode, self.one_vs_rest_param)
        self.load_model()
        self.predict(data_instance)
    elif method == "intersect":
        LOGGER.debug("[Intersect]Input table:{}, input namespace: {}".format(
            self.workflow_param.data_input_table,
            self.workflow_param.data_input_namespace))
        data_instance = self.gen_data_instance(self.workflow_param.data_input_table,
                                               self.workflow_param.data_input_namespace)
        self.intersect(data_instance)
    elif method == "cross_validation":
        data_instance = None
        if self.role != consts.ARBITER:
            data_instance = self.gen_data_instance(self.workflow_param.data_input_table,
                                                   self.workflow_param.data_input_namespace)
        self.cross_validation(data_instance)
    elif method == "one_vs_rest_train":
        LOGGER.debug("In running function, enter one_vs_rest method")
        train_data_instance, predict_data_instance = load_train_and_validation()
        self.one_vs_rest_train(train_data_instance, validation_data=predict_data_instance)
        self._save_pipeline()
    else:
        raise TypeError("method %s is not support yet" % (method))