Example #1
    def _display_result(self, block_num=None):
        if block_num is None:
            self.callback_metric(metric_name=self.metric_name,
                                 metric_namespace=self.metric_namespace,
                                 metric_data=[
                                     Metric("Coverage", self.coverage),
                                     Metric("Block number", self.block_num)
                                 ])
            self.tracker.set_metric_meta(
                metric_namespace=self.metric_namespace,
                metric_name=self.metric_name,
                metric_meta=MetricMeta(self.metric_name,
                                       metric_type="INTERSECTION"))
        else:
            self.callback_metric(metric_name=self.metric_name,
                                 metric_namespace=self.metric_namespace,
                                 metric_data=[
                                     Metric("Coverage", self.coverage),
                                     Metric("Block number", block_num)
                                 ])
            self.tracker.set_metric_meta(
                metric_namespace=self.metric_namespace,
                metric_name=self.metric_name,
                metric_meta=MetricMeta(self.metric_name,
                                       metric_type="INTERSECTION"))
Example #2
def callback(tracker, method, callback_metrics, other_metrics=None):
    LOGGER.debug("callback: method is {}".format(method))
    if method == "random":
        tracker.log_metric_data("sample_count",
                                "random",
                                callback_metrics)

        tracker.set_metric_meta("sample_count",
                                "random",
                                MetricMeta(name="sample_count",
                                           metric_type="SAMPLE_TEXT"))

    else:
        LOGGER.debug(
            "callback: name {}, namespace {}, metrics_data {}".format("sample_count", "stratified", callback_metrics))

        tracker.log_metric_data("sample_count",
                                "stratified",
                                callback_metrics)

        tracker.set_metric_meta("sample_count",
                                "stratified",
                                MetricMeta(name="sample_count",
                                           metric_type="SAMPLE_TABLE"))

        tracker.log_metric_data("original_count",
                                "stratified",
                                other_metrics)

        tracker.set_metric_meta("original_count",
                                "stratified",
                                MetricMeta(name="original_count",
                                           metric_type="SAMPLE_TABLE"))
Example #3
def callback(tracker, method, callback_metrics):
    print("method is {}".format(method))
    if method == "random":
        tracker.log_metric_data("sample_count",
                                "random",
                                callback_metrics)

        tracker.set_metric_meta("sample_count",
                                "random",
                                MetricMeta(name="sample_count",
                                           metric_type="SAMPLE_TEXT"))

    else:
        print("name {}, namespace {}, metrics_data {}".format("sample_count", "stratified", callback_metrics))
        for metric in callback_metrics:
            print("metric is {}".format(metric))

        tracker.log_metric_data("sample_count",
                                "stratified",
                                callback_metrics)

        tracker.set_metric_meta("sample_count",
                                "stratified",
                                MetricMeta(name="sample_count",
                                           metric_type="SAMPLE_TABLE"))
Example #4
    def __save_f1_score_table(self, metric, f1_scores, thresholds, metric_name, metric_namespace):

        extra_metas = {'f1_scores': list(np.round(f1_scores, self.round_num)),
                       'thresholds': list(np.round(thresholds, self.round_num))}

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #5
    def callback_metric(self, metric_name, metric_namespace, metric_data):
        self.tracker.log_metric_data(metric_name=metric_name,
                                     metric_namespace=metric_namespace,
                                     metrics=metric_data)
        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name='download', metric_type='DOWNLOAD'))
Example #6
    def _callback(self):

        self.tracker.set_metric_meta(metric_namespace="statistic",
                                     metric_name="correlation",
                                     metric_meta=MetricMeta(
                                         name="pearson",
                                         metric_type="CORRELATION_GRAPH"))
Example #7
    def __save_curve_meta(self,
                          metric_name,
                          metric_namespace,
                          metric_type,
                          unit_name=None,
                          ordinate_name=None,
                          curve_name=None,
                          best=None,
                          pair_type=None,
                          thresholds=None):
        extra_metas = {}
        metric_type = "_".join([metric_type, "EVALUATION"])

        key_list = [
            "unit_name", "ordinate_name", "curve_name", "best", "pair_type",
            "thresholds"
        ]
        for key in key_list:
            value = locals()[key]
            if value:
                if key == "thresholds":
                    value = np.round(value, self.round_num).tolist()
                extra_metas[key] = value

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name,
                       metric_type=metric_type,
                       extra_metas=extra_metas))
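A hedged call sketch for the helper above. Only the keyword names come from its signature; the curve name, metric type, and threshold values are invented for illustration. Note that the recorded metric_type gets an "_EVALUATION" suffix and that keywords left at None are omitted from extra_metas:

        # Illustrative only: called from inside the same class; stores meta
        # for a ROC-style curve, recorded with metric_type "ROC_EVALUATION".
        self.__save_curve_meta(metric_name="roc_curve",
                               metric_namespace="train",
                               metric_type="ROC",
                               unit_name="fpr",
                               ordinate_name="tpr",
                               curve_name="train",
                               thresholds=[0.1, 0.5, 0.9])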
Example #8
    def fit(self, data):
        """
        Apply scale for input data
        Parameters
        ----------
        data: data_instance, input data

        Returns
        ----------
        data:data_instance, data after scale
        scale_value_results: list, the fit results information of scale
        """
        LOGGER.info("Start scale data fit ...")

        if self.model_param.method == consts.MINMAXSCALE:
            self.scale_obj = MinMaxScale(self.model_param)
        elif self.model_param.method == consts.STANDARDSCALE:
            self.scale_obj = StandardScale(self.model_param)
        else:
            LOGGER.warning("Scale method is {}, do nothing and return!".format(self.model_param.method))

        if self.scale_obj:
            fit_data = self.scale_obj.fit(data)
            fit_data.schema = data.schema

            self.callback_meta(metric_name="scale", metric_namespace="train",
                               metric_meta=MetricMeta(name="scale",
                                                      metric_type="SCALE",
                                                      extra_metas={"method": self.model_param.method}))
        else:
            fit_data = data

        LOGGER.info("End fit data ...")
        return fit_data
Example #9
def callback(keyword="missing_impute",
             value_list=None,
             tracker=None):
    # tracker = Tracking("abc", "123")
    metric_type = None
    """
    if keyword.endswith("ratio"):
        metric_list = []
        for i in range(len(value_list)):
            metric_list.append(Metric(i, value_list[i]))

        tracker.log_metric_data(keyword, "DATAIO", metric_list)

        metric_type = "DATAIO_TABLE"
    """
    metric_list = []
    for i in range(len(value_list)):
        metric_list.append(Metric(value_list[i], i))

    tracker.log_metric_data(keyword, "DATAIO", metric_list)

    metric_type = "DATAIO_TEXT"

    tracker.set_metric_meta(keyword,
                            "DATAIO",
                            MetricMeta(name=keyword,
                                       metric_type=metric_type))
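A minimal usage sketch for this callback, assuming a tracker object with the log_metric_data/set_metric_meta API used above; the per-column values are made up:

    # Hypothetical call: each entry in value_list is logged as Metric(value, index)
    # under the "DATAIO" namespace, and the meta is tagged DATAIO_TEXT.
    callback(keyword="missing_impute", value_list=[0, 2, 5], tracker=tracker)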
Example #10
    def __save_confusion_mat_table(self, metric, confusion_mat, thresholds, metric_name, metric_namespace):

        extra_metas = {'tp': list(confusion_mat['tp']), 'tn': list(confusion_mat['tn']),
                       'fp': list(confusion_mat['fp']),
                       'fn': list(confusion_mat['fn']), 'thresholds': list(np.round(thresholds, self.round_num))}

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #11
    def __save_single_value(self, result, metric_name, metric_namespace,
                            eval_name):
        self.tracker.log_metric_data(
            metric_namespace, metric_name,
            [Metric(eval_name, np.round(result, self.round_num))])
        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type="EVALUATION_SUMMARY"))
Example #12
    def record_step_best(self, step_best, host_mask, guest_mask,
                         data_instances, model):
        metas = {
            "host_mask": host_mask.tolist(),
            "guest_mask": guest_mask.tolist(),
            "score_name": self.score_name
        }
        metas["number_in"] = int(sum(host_mask) + sum(guest_mask))
        metas["direction"] = self.direction
        metas["n_count"] = int(self.n_count)

        host_party_id = model.component_properties.host_party_idlist[0]
        guest_party_id = model.component_properties.guest_partyid
        metas["host_features_anonym"] = [
            f"host_{host_party_id}_{i}" for i in range(len(host_mask))
        ]
        metas["guest_features_anonym"] = [
            f"guest_{guest_party_id}_{i}" for i in range(len(guest_mask))
        ]

        model_info = self.models_trained[step_best]
        loss = model_info.get_loss()
        ic_val = model_info.get_score()
        metas["loss"] = loss
        metas["current_ic_val"] = ic_val
        metas["fit_intercept"] = model.fit_intercept

        model_key = model_info.get_key()
        model_dict = self._get_model(model_key)

        if self.role != consts.ARBITER:
            all_features = data_instances.schema.get('header')
            metas["all_features"] = all_features
            metas["to_enter"] = self.get_to_enter(host_mask, guest_mask,
                                                  all_features)
            model_param = list(model_dict.get('model').values())[0].get(
                model.model_param_name)
            param_dict = MessageToDict(model_param)
            metas["intercept"] = param_dict.get("intercept", None)
            metas["weight"] = param_dict.get("weight", {})
            metas["header"] = param_dict.get("header", [])
            if self.n_step == 0 and self.direction == "forward":
                metas["intercept"] = self.intercept

        metric_name = f"stepwise_{self.n_step}"
        metric = [Metric(metric_name, float(self.n_step))]
        model.callback_metric(metric_name=metric_name,
                              metric_namespace=self.metric_namespace,
                              metric_data=metric)
        model.tracker.set_metric_meta(metric_name=metric_name,
                                      metric_namespace=self.metric_namespace,
                                      metric_meta=MetricMeta(
                                          name=metric_name,
                                          metric_type=self.metric_type,
                                          extra_metas=metas))
        LOGGER.info(f"metric_name: {metric_name}, metas: {metas}")
        return
Example #13
    def __save_pr_table(self, metric, metric_res, metric_name, metric_namespace):

        p_scores, r_scores, score_threshold = metric_res

        extra_metas = {'p_scores': list(map(list, np.round(p_scores, self.round_num))),
                       'r_scores': list(map(list, np.round(r_scores, self.round_num))),
                       'thresholds': list(np.round(score_threshold, self.round_num))}

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #14
    def __save_single_value(self, result, metric_name, metric_namespace, eval_name):

        metric_type = 'EVALUATION_SUMMARY'
        if eval_name in consts.ALL_CLUSTER_METRICS:
            metric_type = 'CLUSTERING_EVALUATION_SUMMARY'

        self.tracker.log_metric_data(metric_namespace, metric_name,
                                     [Metric(eval_name, np.round(result, self.round_num))])
        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric_type))
Example #15
    def save_metric_meta(self,
                         metric_namespace: str,
                         metric_name: str,
                         metric_meta: MetricMeta,
                         job_level: bool = False):
        schedule_logger(self.job_id).info(
            'save job {} component {} on {} {} {} {} metric meta'.format(
                self.job_id, self.component_name, self.role, self.party_id,
                metric_namespace, metric_name))
        self.insert_metrics_into_db(metric_namespace, metric_name, 0,
                                    metric_meta.to_dict().items(), job_level)
Example #16
    def callback(self, metas):
        metric = [Metric(self.metric_name, 0)]
        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=metric)
        self.tracker.set_metric_meta(metric_name=self.metric_name,
                                     metric_namespace=self.metric_namespace,
                                     metric_meta=MetricMeta(
                                         name=self.metric_name,
                                         metric_type=self.metric_type,
                                         extra_metas=metas))
Example #17
    def callback_loss(self, iter_num, loss):
        metric_meta = MetricMeta(name='train',
                                 metric_type=MetricType.LOSS,
                                 extra_metas={
                                     "unit_name": "iters",
                                 })

        self.callback_meta(metric_name='loss', metric_namespace='train', metric_meta=metric_meta)
        self.callback_metric(metric_name='loss',
                             metric_namespace='train',
                             metric_data=[Metric(iter_num, loss)])
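A sketch of the call pattern this implies, assuming a training loop inside the same component that produces one loss value per iteration (the loss values are illustrative):

        # Hypothetical training loop: one loss point is reported per iteration,
        # all under the 'train' namespace with unit_name "iters".
        for iter_num, loss in enumerate([0.69, 0.52, 0.41]):
            self.callback_loss(iter_num, loss)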
Example #18
    def get_metric_meta(self,
                        metric_namespace: str,
                        metric_name: str,
                        job_level: bool = False):
        kv = dict()
        for k, v in self.read_metrics_from_db(metric_namespace, metric_name, 0,
                                              job_level):
            kv[k] = v
        return MetricMeta(name=kv.get('name'),
                          metric_type=kv.get('metric_type'),
                          extra_metas=kv)
Example #19
    def callback_ovr_metric_data(self, eval_results):

        for model_name, eval_rs in eval_results.items():

            train_callback_meta = defaultdict(dict)
            validate_callback_meta = defaultdict(dict)
            split_list = model_name.split('_')
            label = split_list[-1]
            origin_model_name_list = split_list[:-2]  # drop the trailing "class" and label-index parts
            origin_model_name = '_'.join(origin_model_name_list)

            for rs_dict in eval_rs:
                for metric_name, metric_rs in rs_dict.items():
                    if metric_name == consts.KS:
                        metric_rs = [
                            metric_rs[0], metric_rs[1][0]
                        ]  # ks value only, curve data is not needed
                    metric_namespace = metric_rs[0]
                    if metric_namespace == 'train':
                        callback_meta = train_callback_meta
                    else:
                        callback_meta = validate_callback_meta
                    callback_meta[label][metric_name] = metric_rs[1]

            self.tracker.set_metric_meta(
                "train", model_name + '_' + 'ovr',
                MetricMeta(name=origin_model_name,
                           metric_type='ovr',
                           extra_metas=train_callback_meta))
            self.tracker.set_metric_meta(
                "validate", model_name + '_' + 'ovr',
                MetricMeta(name=origin_model_name,
                           metric_type='ovr',
                           extra_metas=validate_callback_meta))

            LOGGER.debug('callback data {} {}'.format(train_callback_meta,
                                                      validate_callback_meta))
Example #20
    def save_meta(self, dst_table_namespace, dst_table_name, table_count):
        self.tracker.log_output_data_info(data_name='upload',
                                          table_namespace=dst_table_namespace,
                                          table_name=dst_table_name)

        self.tracker.log_metric_data(metric_namespace="upload",
                                     metric_name="data_access",
                                     metrics=[Metric("count", table_count)])
        self.tracker.set_metric_meta(metric_namespace="upload",
                                     metric_name="data_access",
                                     metric_meta=MetricMeta(
                                         name='upload', metric_type='UPLOAD'))
Example #21
def callback(tracker,
             method,
             callback_metrics,
             other_metrics=None,
             summary_dict=None):
    LOGGER.debug("callback: method is {}".format(method))
    if method == "random":
        tracker.log_metric_data("sample_count", "random", callback_metrics)

        tracker.set_metric_meta(
            "sample_count", "random",
            MetricMeta(name="sample_count", metric_type="SAMPLE_TEXT"))

        summary_dict["sample_count"] = callback_metrics[0].value

    else:
        LOGGER.debug("callback: name {}, namespace {}, metrics_data {}".format(
            "sample_count", "stratified", callback_metrics))

        tracker.log_metric_data("sample_count", "stratified", callback_metrics)

        tracker.set_metric_meta(
            "sample_count", "stratified",
            MetricMeta(name="sample_count", metric_type="SAMPLE_TABLE"))

        tracker.log_metric_data("original_count", "stratified", other_metrics)

        tracker.set_metric_meta(
            "original_count", "stratified",
            MetricMeta(name="original_count", metric_type="SAMPLE_TABLE"))

        summary_dict["sample_count"] = {}
        for sample_metric in callback_metrics:
            summary_dict["sample_count"][
                sample_metric.key] = sample_metric.value

        summary_dict["original_count"] = {}
        for sample_metric in other_metrics:
            summary_dict["original_count"][
                sample_metric.key] = sample_metric.value
Example #22
    def fit(self, data_inst, valid_inst=None):

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()
        self.tree_dim = 1

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = self.label_alignment()
            LOGGER.debug('label mapping is {}'.format(label_mapping))
            self.tree_dim = len(label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.debug('begin to fit a boosting tree')
        for epoch_idx in range(self.num_trees):

            for t_idx in range(self.tree_dim):
                valid_feature = self.sample_valid_feature()
                self.send_valid_features(valid_feature, epoch_idx, t_idx)
                flow_id = self.generate_flowid(epoch_idx, t_idx)
                new_tree = HomoDecisionTreeArbiter(self.tree_param,
                                                   valid_feature=valid_feature,
                                                   epoch_idx=epoch_idx,
                                                   flow_id=flow_id,
                                                   tree_idx=t_idx)
                new_tree.fit()

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.global_loss_history.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.global_loss_history)}))

        LOGGER.debug('fitting homo decision tree done')
Example #23
    def __save_contingency_matrix(self, metric, metric_res, metric_name, metric_namespace):

        result_array, unique_predicted_label, unique_true_label = metric_res
        true_labels = list(map(int, unique_true_label))
        predicted_label = list(map(int, unique_predicted_label))
        result_table = []
        for l_ in result_array:
            result_table.append(list(map(int, l_)))

        extra_metas = {'true_labels': true_labels, 'predicted_labels': predicted_label, 'result_table': result_table}

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #24
    def callback_dbi(self, iter_num, dbi):
        metric_meta = MetricMeta(name='train',
                                 metric_type="DBI",
                                 extra_metas={
                                     "unit_name": "iters",
                                 })

        self.callback_meta(metric_name='DBI',
                           metric_namespace='train',
                           metric_meta=metric_meta)
        self.callback_metric(metric_name='DBI',
                             metric_namespace='train',
                             metric_data=[Metric(iter_num, dbi)])
Example #25
    def fit(self, data_inst, validate_data=None):

        # init aggregator
        self.aggregator = HomoBoostArbiterAggregator()
        self.binning_obj = HomoFeatureBinningServer()

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            self.booster_dim = len(label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.info('begin to fit a boosting tree')
        for epoch_idx in range(self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                model = self.fit_a_booster(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.history_loss)}))

        self.set_summary(self.generate_summary())
Example #26
    def callback_loss(self, iter_num, loss):
        # noinspection PyTypeChecker
        metric_meta = MetricMeta(name='train',
                                 metric_type="LOSS",
                                 extra_metas={
                                     "unit_name": "iters",
                                 })

        self.callback_meta(metric_name='loss', metric_namespace='train', metric_meta=metric_meta)
        self.callback_metric(metric_name='loss',
                             metric_namespace='train',
                             metric_data=[Metric(iter_num, loss)])

        self._summary["loss_history"].append(loss)
Example #27
    def __save_psi_table(self, metric, metric_res, metric_name, metric_namespace):

        psi_scores, total_psi, expected_interval, expected_percentage, actual_interval, actual_percentage, \
        train_pos_perc, validate_pos_perc, intervals = metric_res[1]

        extra_metas = {'psi_scores': list(np.round(psi_scores, self.round_num)),
                       'total_psi': round(total_psi, self.round_num),
                       'expected_interval': list(expected_interval),
                       'expected_percentage': list(expected_percentage), 'actual_interval': list(actual_interval),
                       'actual_percentage': list(actual_percentage), 'intervals': list(intervals),
                       'train_pos_perc': train_pos_perc, 'validate_pos_perc': validate_pos_perc
                       }

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #28
    def __save_distance_measure(self, metric, metric_res: dict, metric_name, metric_namespace):

        extra_metas = {}
        cluster_index = list(metric_res.keys())
        radius, nearest_idx = [], []
        for k in metric_res:
            radius.append(metric_res[k][0])
            nearest_idx.append(metric_res[k][1])

        extra_metas['cluster_index'] = cluster_index
        extra_metas['radius'] = radius
        extra_metas['nearest_idx'] = nearest_idx

        self.tracker.set_metric_meta(metric_namespace, metric_name,
                                     MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))
Example #29
    def fit(self, data):
        self.__init_intersect_method()

        if self.model_param.repeated_id_process:
            if self.model_param.intersect_cache_param.use_cache is True and self.model_param.intersect_method == consts.RSA:
                raise ValueError(
                    "Not support cache module while repeated id process.")

            if len(self.host_party_id_list) > 1 and \
                    self.model_param.repeated_id_owner != consts.GUEST:
                raise ValueError(
                    "While multi-host, repeated_id_owner should be guest.")

            proc_obj = RepeatedIDIntersect(
                repeated_id_owner=self.model_param.repeated_id_owner,
                role=self.role)
            data = proc_obj.run(data=data)

        if self.model_param.allow_info_share:
            if self.model_param.intersect_method == consts.RSA and self.model_param.info_owner == consts.GUEST \
                    or self.model_param.intersect_method == consts.RAW and self.model_param.join_role == self.model_param.info_owner:
                self.model_param.sync_intersect_ids = False

        self.intersect_ids = self.intersection_obj.run(data)

        if self.model_param.allow_info_share:
            self.intersect_ids = self.__share_info(data)

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            self.intersect_num = self.intersect_ids.count()
            self.intersect_rate = self.intersect_num * 1.0 / data.count()

        self.set_summary(self.get_model_summary())

        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=[
                                 Metric("intersect_count", self.intersect_num),
                                 Metric("intersect_rate", self.intersect_rate)
                             ])
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(
                                         name=self.metric_name,
                                         metric_type=self.metric_type))
Example #30
File: scale.py  Project: zpskt/FATE
    def transform(self, data, fit_config=None):
        """
        Transform input data using scale with fit results
        Parameters
        ----------
        data: data_instance, input data
        fit_config: list, the fit results information of scale

        Returns
        ----------
        transform_data:data_instance, data after transform
        """
        LOGGER.info("Start scale data transform ...")

        if self.model_param.method == consts.MINMAXSCALE:
            self.scale_obj = MinMaxScale(self.model_param)
        elif self.model_param.method == consts.STANDARDSCALE:
            self.scale_obj = StandardScale(self.model_param)
            self.scale_obj.set_param(self.mean, self.std)
        else:
            LOGGER.info(
                "DataTransform method is {}, do nothing and return!".format(
                    self.model_param.method))

        if self.scale_obj:
            self.scale_obj.header = self.header
            self.scale_obj.scale_column_idx = self.scale_column_idx
            self.scale_obj.set_column_range(self.column_max_value,
                                            self.column_min_value)
            transform_data = self.scale_obj.transform(data)
            transform_data.schema = data.schema

            self.callback_meta(
                metric_name="scale",
                metric_namespace="train",
                metric_meta=MetricMeta(
                    name="scale",
                    metric_type="SCALE",
                    extra_metas={"method": self.model_param.method}))

        else:
            transform_data = data

        LOGGER.info("End transform data.")

        return transform_data