Esempio n. 1
0
    def predict(self, guest_data):
        """Guest-side prediction step.

        Loads the guest's features and labels into ``guest_model``, fetches
        the host's ``host_prob`` transfer variable, computes ``pred_prob``
        from it with ``guest_model.predict`` and sends that result to the host.

        Args:
            guest_data: instance table accepted by
                ``convert_instance_table_to_array``.

        Returns:
            None -- the prediction is only delivered to the host through
            ``_do_remote``; nothing is returned locally.
        """
        LOGGER.info("@ start guest predict")
        features, labels, _ = convert_instance_table_to_array(guest_data)
        x_batch = np.squeeze(features)
        # expand_dims(axis=1) inserts an axis after the first one,
        # e.g. a label vector of shape (n,) becomes (n, 1).
        y_batch = np.expand_dims(labels, axis=1)
        LOGGER.debug("guest_x, guest_y: %s, %s" % (x_batch.shape, y_batch.shape))

        variables = self.transfer_variable
        received = self._do_get(
            name=variables.host_prob.name,
            tag=variables.generate_transferid(variables.host_prob),
            idx=-1,
        )
        # Only the first received item is used.
        host_prob = received[0]

        model = self.guest_model
        model.set_batch(x_batch, y_batch)
        pred_prob = model.predict(host_prob)
        LOGGER.debug("pred_prob: %s" % (pred_prob.shape,))

        self._do_remote(
            pred_prob,
            name=variables.pred_prob.name,
            tag=variables.generate_transferid(variables.pred_prob),
            role=consts.HOST,
            idx=-1,
        )
Esempio n. 2
0
    def predict(self, host_data, predict_param):
        """Host-side prediction step; returns the joined prediction table.

        Computes ``host_prob`` with ``host_model`` and sends it to the guest,
        receives the guest's ``pred_prob`` back, thresholds it with
        ``self.classified`` and joins the result with the actual labels.
        All tables are built with ``create_table`` over ``instances_indexes``.

        Args:
            host_data: instance table accepted by
                ``convert_instance_table_to_array``.
            predict_param: object providing ``threshold`` (passed to
                ``self.classified``) and ``with_proba`` (whether the predicted
                probability is included in each result tuple).

        Returns:
            The joined table whose values are
            ``(actual_label, pred_prob, pred_label)`` tuples. ``pred_prob`` is
            ``None`` when ``predict_param.with_proba`` is false. In both cases
            an actual label that is not ``> 0`` is reported as ``0``.
        """
        LOGGER.info("@ start host predict")
        features, labels, instances_indexes = convert_instance_table_to_array(
            host_data)
        host_x = np.squeeze(features)
        LOGGER.debug("host_x: " + str(host_x.shape))

        host_prob = self.host_model.predict(host_x)
        self._do_remote(host_prob,
                        name=self.transfer_variable.host_prob.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_prob),
                        role=consts.GUEST,
                        idx=-1)

        pred_prob = self._do_get(
            name=self.transfer_variable.pred_prob.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.pred_prob),
            idx=-1)[0]

        # np.squeeze alone collapses a single-instance result such as (1, 1)
        # into a 0-d array, which presumably create_table cannot pair with
        # instances_indexes. atleast_1d turns that case into shape (1,) and
        # leaves every array with ndim >= 1 unchanged.
        pred_prob = np.atleast_1d(np.squeeze(pred_prob))
        LOGGER.debug("pred_prob: " + str(pred_prob.shape))

        pred_prob_table = create_table(pred_prob, instances_indexes)
        actual_label_table = create_table(labels, instances_indexes)
        pred_label_table = self.classified(pred_prob_table,
                                           predict_param.threshold)
        # Actual labels are mapped with ``label if label > 0 else 0`` in BOTH
        # branches, so the reported label encoding (e.g. -1 -> 0) does not
        # depend on whether probabilities are requested.
        if predict_param.with_proba:
            predict_result = actual_label_table.join(
                pred_prob_table,
                lambda label, prob: (label if label > 0 else 0, prob))
            predict_result = predict_result.join(
                pred_label_table, lambda x, y: (x[0], x[1], y))
        else:
            predict_result = actual_label_table.join(
                pred_label_table,
                lambda label, p_label: (label if label > 0 else 0, None,
                                        p_label))
        return predict_result