示例#1
0
    def __test(host_sample_indexes, guest_sample_indexes, before_overlap_indexes, before_host_nonoverlap_indexes):
        """Cross-check overlapping_samples_converter against fetch_overlap_data.

        Random features (one (1, 3) array per sample) and random 0/1 labels
        are generated for every id in ``host_sample_indexes``. The expected
        overlap / non-overlap data is fetched with the caller-supplied
        ``before_*`` index collections, and compared with what the converter
        output yields when indexed by the indexes the converter itself returns.

        Args:
            host_sample_indexes: sample ids held by the host; also the keys
                of the generated feature and label dicts.
            guest_sample_indexes: sample ids held by the guest, forwarded to
                the converter as the reference side.
            before_overlap_indexes: ids passed to fetch_overlap_data as the
                expected overlapping samples.
            before_host_nonoverlap_indexes: ids passed to fetch_overlap_data
                as the expected host-only samples.
        """
        # Fixed seed keeps the generated fixtures reproducible between runs.
        np.random.seed(100)
        features_by_id, labels_by_id = {}, {}
        for sample_id in host_sample_indexes:
            # Draw order (features first, then label) is kept per sample so
            # the random stream is consumed exactly as before.
            features_by_id[sample_id] = np.random.rand(1, 3)
            labels_by_id[sample_id] = np.random.randint(0, 2)

        # Expected values, looked up directly through the supplied indexes.
        want_overlap_x, want_rest_x = fetch_overlap_data(
            features_by_id, before_overlap_indexes, before_host_nonoverlap_indexes)
        want_overlap_y, want_rest_y = fetch_overlap_data(
            labels_by_id, before_overlap_indexes, before_host_nonoverlap_indexes)

        # Drop the length-1 axes left over from the (1, 3) per-sample arrays
        # and give the labels a second axis.
        # NOTE(review): presumably this matches the converter's output layout;
        # confirm against overlapping_samples_converter.
        want_overlap_x = np.squeeze(want_overlap_x)
        want_rest_x = np.squeeze(want_rest_x)
        want_overlap_y = np.expand_dims(want_overlap_y, axis=1)
        want_rest_y = np.expand_dims(want_rest_y, axis=1)

        converted_x, overlap_pos, rest_pos, converted_y = overlapping_samples_converter(
            features_by_id, host_sample_indexes, guest_sample_indexes, labels_by_id)

        expected = (want_overlap_x, want_rest_x, want_overlap_y, want_rest_y)
        actual = (
            converted_x[overlap_pos],
            converted_x[rest_pos],
            converted_y[overlap_pos],
            converted_y[rest_pos],
        )
        for want, got in zip(expected, actual):
            assert_matrix(want, got)
示例#2
0
    def prepare_data(self, guest_data):
        """Exchange sample indexes with the host and convert the guest data.

        The guest instance table is turned into feature/label dicts plus a
        list of sample indexes. The indexes are sent to the host role, the
        host's indexes are fetched in return, and both index collections are
        handed to ``overlapping_samples_converter`` together with the dicts.

        Args:
            guest_data: guest-side instance table accepted by
                ``convert_instance_table_to_dict``.

        Returns:
            A 4-tuple ``(guest_features, overlap_indexes,
            non_overlap_indexes, guest_label)`` as produced by
            ``overlapping_samples_converter``.
        """
        LOGGER.info("@ start guest prepare_data")
        features_by_id, labels_by_id, own_indexes = convert_instance_table_to_dict(guest_data)
        own_indexes = np.array(own_indexes)
        shape_text = str(own_indexes.shape)
        LOGGER.debug("@ send guest_sample_indexes shape" + shape_text)

        # Push our indexes to the host.
        # NOTE(review): idx=-1 presumably addresses every party of that role; confirm.
        transfer = self.transfer_variable
        outgoing = transfer.guest_sample_indexes
        self._do_remote(
            own_indexes,
            name=outgoing.name,
            tag=transfer.generate_transferid(outgoing),
            role=consts.HOST,
            idx=-1,
        )

        # Pull the host's indexes; only the first element of the reply is used.
        incoming = transfer.host_sample_indexes
        replies = self._do_get(
            name=incoming.name,
            tag=transfer.generate_transferid(incoming),
            idx=-1,
        )
        peer_indexes = replies[0]
        LOGGER.debug("@ receive host_sample_indexes len" + str(len(peer_indexes)))

        features, overlap_pos, rest_pos, labels = overlapping_samples_converter(
            features_by_id, own_indexes, peer_indexes, labels_by_id)
        return features, overlap_pos, rest_pos, labels
示例#3
0
    def prepare_data(self, host_data):
        """Exchange sample indexes with the guest and convert the host data.

        The host instance table is turned into a feature dict plus a list of
        sample indexes (the label part of the conversion is discarded). The
        indexes are sent to the guest role, the guest's indexes are fetched in
        return, and both are handed to ``overlapping_samples_converter``.

        Args:
            host_data: host-side instance table accepted by
                ``convert_instance_table_to_dict``.

        Returns:
            A 2-tuple ``(host_features, overlap_indexes)``; the third value
            returned by the converter is dropped.
        """
        LOGGER.info("@ start host prepare data")
        features_by_id, _, own_indexes = convert_instance_table_to_dict(host_data)
        own_indexes = np.array(own_indexes)

        # Push our indexes to the guest.
        # NOTE(review): idx=-1 presumably addresses every party of that role; confirm.
        transfer = self.transfer_variable
        outgoing = transfer.host_sample_indexes
        self._do_remote(
            own_indexes,
            name=outgoing.name,
            tag=transfer.generate_transferid(outgoing),
            role=consts.GUEST,
            idx=-1,
        )

        # Pull the guest's indexes; only the first element of the reply is used.
        incoming = transfer.guest_sample_indexes
        replies = self._do_get(
            name=incoming.name,
            tag=transfer.generate_transferid(incoming),
            idx=-1,
        )
        peer_indexes = replies[0]

        features, overlap_pos, _ = overlapping_samples_converter(
            features_by_id, own_indexes, peer_indexes)
        return features, overlap_pos