Пример #1
0
    def one_vs_rest_predict(self, data_instance):
        """Predict with a previously saved one-vs-rest model.

        In hetero mode the input is first passed through ``self.intersect``.
        A ``OneVsRest`` wrapper is then built around ``self.model``, its
        state is loaded from the workflow's model table/namespace, and it is
        asked to predict on the data.

        :param data_instance: data to predict on.
        :return: the prediction result, or ``None`` when that result is falsy.
        """
        if self.mode == consts.HETERO:
            LOGGER.debug("Star intersection before predict")
            data_instance = self.intersect(data_instance, "predict_module_0")
            LOGGER.debug("End intersection before predict")

        # Feature-selection transform and scaling are currently disabled on
        # this path.

        self.one_vs_rest_param = self._load_param(OneVsRestParam())
        ovr_model = OneVsRest(
            self.model, self.role, self.mode, self.one_vs_rest_param)
        ovr_model.load_model(
            self.workflow_param.model_table,
            self.workflow_param.model_namespace)
        predict_result = ovr_model.predict(
            data_instance, self.workflow_param.predict_param)

        if not predict_result:
            return None

        # Log the first ten predictions, but only when the result holds more
        # than ten entries.
        if predict_result.count() > 10:
            entries = predict_result.collect()
            for _ in range(10):
                LOGGER.debug("predict result: {}".format(next(entries)))

        return predict_result
Пример #2
0
    def train(self, train_data, validation_data=None):
        """Fit ``self.model`` on the training data, save it, then evaluate it.

        Pre-processing is applied in this order: intersection (hetero mode,
        non-arbiter roles only), sampling, feature selection, scaling (hetero
        mode, non-arbiter roles only) and one-hot encoding.  When
        ``workflow_param.one_vs_rest`` is set, ``self.model`` is replaced by a
        ``OneVsRest`` wrapper before fitting.

        After the model has been saved, guest and host roles (and every role
        in homo mode) predict on the training data and evaluate the result.
        If validation data was supplied it is predicted and evaluated as
        well.  The collected results are handed to ``self.save_eval_result``.

        :param train_data: training data set.
        :param validation_data: optional validation data set.
        """
        if self.mode == consts.HETERO and self.role != consts.ARBITER:
            LOGGER.debug("Enter train function")
            LOGGER.debug("Star intersection before train")
            intersect_flowid = "train_0"
            train_data = self.intersect(train_data, intersect_flowid)
            LOGGER.debug("End intersection before train")

        sample_flowid = "train_sample_0"
        train_data = self.sample(train_data, sample_flowid)

        train_data = self.feature_selection_fit(train_data)
        validation_data = self.feature_selection_transform(validation_data)

        if self.mode == consts.HETERO and self.role != consts.ARBITER:
            # The scale values obtained here are reused further down to scale
            # the validation data.
            train_data, cols_scale_value = self.scale(train_data)

        train_data = self.one_hot_encoder_fit_transform(train_data)
        validation_data = self.one_hot_encoder_transform(validation_data)

        if self.workflow_param.one_vs_rest:
            one_vs_rest_param = OneVsRestParam()
            self.one_vs_rest_param = ParamExtract.parse_param_from_config(
                one_vs_rest_param, self.config_path)
            one_vs_rest = OneVsRest(self.model, self.role, self.mode,
                                    self.one_vs_rest_param)
            self.model = one_vs_rest

        self.model.fit(train_data)
        self.save_model()
        LOGGER.debug("finish saving, self role: {}".format(self.role))

        # Evaluation runs for guest and host roles, or for any role in homo
        # mode.  The mode constant had been mangled into ``consts.H**O``,
        # which Python evaluates as ``consts.H ** O`` and which therefore
        # failed at runtime for every role other than guest/host.
        if self.role == consts.GUEST or self.role == consts.HOST or \
                self.mode == consts.HOMO:
            eval_result = {}
            LOGGER.debug("predicting...")
            predict_result = self.model.predict(
                train_data, self.workflow_param.predict_param)

            LOGGER.debug("evaluating...")
            train_eval = self.evaluate(predict_result)
            eval_result[consts.TRAIN_EVALUATE] = train_eval
            if validation_data is not None:
                self.model.set_flowid("1")
                if self.mode == consts.HETERO:
                    LOGGER.debug("Star intersection before predict")
                    intersect_flowid = "predict_0"
                    validation_data = self.intersect(validation_data,
                                                     intersect_flowid)
                    LOGGER.debug("End intersection before predict")

                    # Scale the validation data with the values produced
                    # when the training data was scaled.
                    validation_data, cols_scale_value = self.scale(
                        validation_data, cols_scale_value)

                val_pred = self.model.predict(
                    validation_data, self.workflow_param.predict_param)
                val_eval = self.evaluate(val_pred)
                eval_result[consts.VALIDATE_EVALUATE] = val_eval
            LOGGER.info("{} eval_result: {}".format(self.role, eval_result))
            self.save_eval_result(eval_result)
Пример #3
0
    def one_vs_rest_train(self, train_data, validation_data=None):
        """Fit a ``OneVsRest`` wrapper around ``self.model``, predict and save it.

        The one-vs-rest parameters are parsed from ``self.config_path``.  The
        wrapper is fitted on ``train_data``, asked to predict on
        ``validation_data`` (the prediction output is not used here), and then
        saved to the workflow's model table/namespace.  Every
        ``(meta, param)`` pair returned by the save is appended to the
        pipeline's ``node_meta`` and ``node_param`` lists.

        :param train_data: data the wrapper is fitted on.
        :param validation_data: data passed to the wrapper's ``predict``.
        """
        self.one_vs_rest_param = ParamExtract.parse_param_from_config(
            OneVsRestParam(), self.config_path)
        ovr_model = OneVsRest(
            self.model, self.role, self.mode, self.one_vs_rest_param)

        LOGGER.debug("Start OneVsRest train")
        ovr_model.fit(train_data)

        LOGGER.debug("Start OneVsRest predict")
        ovr_model.predict(validation_data, self.workflow_param.predict_param)

        saved_buffers = ovr_model.save_model(
            self.workflow_param.model_table,
            self.workflow_param.model_namespace)

        # Nothing to register in the pipeline when the save returned None.
        if saved_buffers is not None:
            for meta_buffer, param_buffer in saved_buffers:
                self.pipeline.node_meta.append(meta_buffer)
                self.pipeline.node_param.append(param_buffer)
Пример #4
0
    def run(self, config_json, job_id):
        """Initialise the workflow and execute ``workflow_param.method``.

        Supported methods are ``"train"``, ``"predict"``, ``"intersect"``,
        ``"cross_validation"`` and ``"one_vs_rest_train"``.

        :param config_json: forwarded to ``self._init_argument``.
        :param job_id: forwarded to ``self._init_argument``.
        :raises TypeError: if the configured method is none of the above.
        """
        self._init_argument(config_json, job_id)
        params = self.workflow_param
        method = params.method

        def load_logged(table, namespace, **options):
            # Log the table coordinates, then build the data instance.
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                table, namespace))
            return self.gen_data_instance(table, namespace, **options)

        def load_train_and_validation(**validation_options):
            # Input tables are only read for non-arbiter roles.
            if self.role == consts.ARBITER:
                return None, None
            train_set = load_logged(params.train_input_table,
                                    params.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            # Validation data needs both a table and a namespace.
            if params.predict_input_table is None \
                    or params.predict_input_namespace is None:
                return train_set, None
            validation_set = load_logged(params.predict_input_table,
                                         params.predict_input_namespace,
                                         **validation_options)
            return train_set, validation_set

        if method == "train":
            LOGGER.debug("In running function, enter train method")
            train_set, validation_set = load_train_and_validation(
                mode='transform')
            self.train(train_set, validation_data=validation_set)
            self._save_pipeline()

        elif method == "predict":
            data_instance = self.gen_data_instance(
                params.predict_input_table,
                params.predict_input_namespace,
                mode='transform')
            if params.one_vs_rest:
                # Wrap the current model before its saved state is loaded.
                self.one_vs_rest_param = self._load_param(OneVsRestParam())
                self.model = OneVsRest(self.model, self.role, self.mode,
                                       self.one_vs_rest_param)
            self.load_model()
            self.predict(data_instance)

        elif method == "intersect":
            table = params.data_input_table
            namespace = params.data_input_namespace
            LOGGER.debug(
                "[Intersect]Input table:{}, input namesapce: {}".format(
                    table, namespace))
            self.intersect(self.gen_data_instance(table, namespace))

        elif method == "cross_validation":
            data_instance = None
            if self.role != consts.ARBITER:
                data_instance = self.gen_data_instance(
                    params.data_input_table, params.data_input_namespace)
            self.cross_validation(data_instance)

        elif method == "one_vs_rest_train":
            LOGGER.debug("In running function, enter one_vs_rest method")
            # Unlike the "train" branch, the validation table is generated
            # without mode='transform'.
            train_set, validation_set = load_train_and_validation()
            self.one_vs_rest_train(train_set, validation_data=validation_set)
            self._save_pipeline()

        else:
            raise TypeError("method %s is not support yet" %
                            (self.workflow_param.method))

        LOGGER.debug("run_DONE")
Пример #5
0
    def run(self):
        """Initialise the workflow and execute ``workflow_param.method``.

        Supported methods are ``"train"``, ``"neighbors_sampling"``,
        ``"predict"``, ``"intersect"``, ``"cross_validation"`` and
        ``"one_vs_rest_train"``.

        :raises TypeError: if the configured method is none of the above.
        """
        self._init_argument()
        params = self.workflow_param
        method = params.method

        def load_logged(table, namespace, **options):
            # Log the table coordinates, then build the data instance.
            LOGGER.debug("Input table:{}, input namesapce: {}".format(
                table, namespace))
            return self.gen_data_instance(table, namespace, **options)

        def load_train_and_validation(**validation_options):
            # Input tables are only read for non-arbiter roles.
            if self.role == consts.ARBITER:
                return None, None
            train_set = load_logged(params.train_input_table,
                                    params.train_input_namespace)
            LOGGER.debug("gen_data_finish")
            # Validation data needs both a table and a namespace.
            if params.predict_input_table is None \
                    or params.predict_input_namespace is None:
                return train_set, None
            validation_set = load_logged(params.predict_input_table,
                                         params.predict_input_namespace,
                                         **validation_options)
            return train_set, validation_set

        if method == "train":
            LOGGER.debug("In running function, enter train method")
            train_set, validation_set = load_train_and_validation(
                mode='transform')
            self.train(train_set, validation_data=validation_set)
            self._save_pipeline()

        elif method == 'neighbors_sampling':
            LOGGER.debug("In running function, enter neighbors sampling")
            LOGGER.debug(
                "[Neighbors sampling]Input table:{}, input namespace:{}".format(
                    params.data_input_table, params.data_input_namespace))
            input_instances = self.gen_data_instance(
                params.data_input_table, params.data_input_namespace)

            LOGGER.info("{}".format(params.local_samples_namespace))
            LOGGER.info("{}".format(params.distributed_samples_namespace))

            # First intersection runs on the full input table.
            common_instances = self.intersect(
                input_instances, 'neigh_sam_intersect_0')
            LOGGER.info("The number of common nodes: {}".format(
                common_instances.count()))

            # Local sampling works on the full input table; its output is
            # persisted under the role's name.
            local_samples = self.neighbors_sampler.local_neighbors_sampling(
                input_instances, self.role)
            local_samples.save_as(name=self.role,
                                  namespace=params.local_samples_namespace,
                                  partition=10)

            # Bridge nodes are derived from the common instances and then
            # intersected a second time under a separate flow id.
            bridge_instances = self.intersect(
                NeighborsSampling.get_bridge_nodes(common_instances),
                'neigh_sam_intersect_1')
            logDtableInstances(LOGGER, bridge_instances, 5)

            target_samples, anchor_samples = \
                self.neighbors_sampler.distributed_neighbors_sampling(
                    bridge_instances, input_instances)

            # Both distributed outputs are stored in a per-role namespace.
            role_namespace = \
                params.distributed_samples_namespace + "/" + self.role
            target_samples.save_as(name="target",
                                   namespace=role_namespace,
                                   partition=10)
            anchor_samples.save_as(name='anchor',
                                   namespace=role_namespace,
                                   partition=10)

            if self.role == 'host':
                LOGGER.info("Neighbors_sampling_finish")

        elif method == "predict":
            data_instance = self.gen_data_instance(
                params.predict_input_table,
                params.predict_input_namespace,
                mode='transform')
            if params.one_vs_rest:
                # Wrap the current model before its saved state is loaded.
                self.one_vs_rest_param = ParamExtract.parse_param_from_config(
                    OneVsRestParam(), self.config_path)
                self.model = OneVsRest(self.model, self.role, self.mode,
                                       self.one_vs_rest_param)
            self.load_model()
            self.predict(data_instance)

        elif method == "intersect":
            table = params.data_input_table
            namespace = params.data_input_namespace
            LOGGER.debug(
                "[Intersect]Input table:{}, input namespace: {}".format(
                    table, namespace))
            self.intersect(self.gen_data_instance(table, namespace))

        elif method == "cross_validation":
            data_instance = None
            if self.role != consts.ARBITER:
                data_instance = self.gen_data_instance(
                    params.data_input_table, params.data_input_namespace)
            self.cross_validation(data_instance)

        elif method == "one_vs_rest_train":
            LOGGER.debug("In running function, enter one_vs_rest method")
            # Unlike the "train" branch, the validation table is generated
            # without mode='transform'.
            train_set, validation_set = load_train_and_validation()
            self.one_vs_rest_train(train_set, validation_data=validation_set)
            self._save_pipeline()

        else:
            raise TypeError("method %s is not support yet" % (self.workflow_param.method))