예제 #1
0
    def _split_column(self, inputs):
        """
            Restrict the main resource to a random subset of its attribute columns.

            Target columns (semantic type ``TrueTarget``) and index columns are looked
            up in ``self._training_inputs.metadata``; they are always kept. When
            ``self._status is Status.TRAIN`` and the number of remaining attribute
            columns exceeds ``self._threshold_column_length``, exactly
            ``self._threshold_column_length`` attribute columns are drawn with
            ``random.sample`` and stored, together with the target and index columns,
            as a sorted list in ``self._column_remained``.

            Whenever ``self._column_remained`` is non-empty (including a selection made
            by an earlier TRAIN call), the main resource of the returned dataset and its
            metadata are restricted to those column positions.

            :param inputs: the input dataset, indexed by ``self._main_resource_id``.
            :return: a shallow copy (``copy.copy``) of ``inputs``, with the main
                resource possibly reduced to ``self._column_remained``.
        """
        input_dataset_shape = inputs[self._main_resource_id].shape
        # Target columns must never be removed by the sampling below.
        target_column = DataMetadata.list_columns_with_semantic_types(
            self._training_inputs.metadata,
            ['https://metadata.datadrivendiscovery.org/types/TrueTarget'],
            at=(self._main_resource_id, ))
        if not target_column:
            # ``warning`` rather than the deprecated ``warn`` alias.
            self._logger.warning("No target column found from the input dataset.")
        # Index columns must never be removed either.
        index_column = DataMetadata.get_index_columns(
            self._training_inputs.metadata, at=(self._main_resource_id, ))
        if not index_column:
            self._logger.warning("No index column found from the input dataset.")

        outputs = copy.copy(inputs)
        if self._status is Status.TRAIN:
            # A column can be both a target and an index column; a set counts (and
            # later keeps) it once instead of double-subtracting it or raising
            # KeyError on a second ``set.remove``.
            protected_columns = set(target_column) | set(index_column)
            # A sorted list, not a set: random.sample() rejects sets since
            # Python 3.11, and sorting gives the sample a deterministic population order.
            attribute_column = sorted(
                set(range(input_dataset_shape[1])) - protected_columns)

            # Only sample when there are more attribute columns than the threshold.
            if len(attribute_column) > self._threshold_column_length:
                sampled_columns = random.sample(
                    attribute_column, self._threshold_column_length)
                # Keep the sampled attributes plus every protected column, in
                # positional order.
                self._column_remained = sorted(
                    set(sampled_columns) | protected_columns)

        if len(self._column_remained) > 0:
            # Defensive copy so the selection below starts from metadata that is
            # independent of ``inputs.metadata``.
            outputs.metadata = copy.deepcopy(inputs.metadata)
            outputs[self._main_resource_id] = inputs[
                self._main_resource_id].iloc[:, self._column_remained]
            outputs.metadata = self._select_columns_metadata(
                outputs.metadata, self._main_resource_id,
                self._column_remained)

        return outputs
예제 #2
0
def combine_columns_metadata(
    inputs: metadata_base.DataMetadata,
    column_indices: typing.Sequence[int],
    columns_list: typing.Sequence[metadata_base.DataMetadata],
    *,
    return_result: str,
    add_index_columns: bool,
) -> metadata_base.DataMetadata:
    """
    Analogous to ``combine_columns`` but operates only on metadata.

    :param inputs: metadata of the original columns.
    :param column_indices: positions in ``inputs`` to be taken over by the new columns;
        only consulted when ``return_result`` is ``'replace'``.
    :param columns_list: metadata of the newly produced column groups, combined in order.
    :param return_result: ``'append'`` adds every group after the columns of ``inputs``;
        ``'replace'`` puts the groups in place of ``column_indices`` (behaving like
        ``'append'`` when ``column_indices`` is empty); ``'new'`` returns only the new
        groups, optionally prefixed with the index columns of ``inputs``.
    :param add_index_columns: only used in ``'new'`` mode: prepend the index columns of
        ``inputs`` when the combined result has no index column of its own.
    :return: the combined metadata.
    :raises ValueError: in ``'new'`` mode, when ``columns_list`` is empty or every group
        has zero columns.
    :raises exceptions.InvalidArgumentValueError: for any other ``return_result`` value.
    """

    def column_count(metadata):
        # Width of a column group, as recorded in its dimension metadata.
        return metadata.query_field((metadata_base.ALL_ELEMENTS, ), 'dimension')['length']

    if return_result == 'append' or (return_result == 'replace' and not column_indices):
        # Nothing to replace: every group simply goes after the existing columns.
        merged = inputs
        for group in columns_list:
            merged = merged.append_columns(group)
        return merged

    if return_result == 'replace':
        merged = inputs
        slots = len(column_indices)
        consumed = 0
        for group in columns_list:
            width = column_count(group)
            if consumed < slots:
                # Take over the listed positions this group covers. The slice may be
                # shorter than the group; then only the listed columns are replaced and
                # the rest are appended after the last replaced column.
                merged = merged.replace_columns(
                    group, column_indices[consumed:consumed + width])
            else:
                # Listed positions are exhausted. The last listed position, shifted by
                # the surplus columns already placed beyond it, is the last occupied
                # column; insert right after it.
                merged = merged.insert_columns(
                    group, column_indices[-1] + (consumed - slots) + 1)
            consumed += width

        if consumed < slots:
            # Fewer new columns than listed positions: drop the unused ones.
            merged = merged.remove_columns(column_indices[consumed:slots])
        return merged

    if return_result == 'new':
        if not any(column_count(group) for group in columns_list):
            raise ValueError("No columns produced.")

        merged = columns_list[0]
        for group in columns_list[1:]:
            merged = merged.append_columns(group)

        if add_index_columns:
            source_index_columns = inputs.get_index_columns()
            merged_index_columns = merged.get_index_columns()
            if source_index_columns and not merged_index_columns:
                # Put the index columns of ``inputs`` in front of the new columns.
                merged = inputs.select_columns(source_index_columns).append_columns(
                    merged, use_right_metadata=True)
        return merged

    raise exceptions.InvalidArgumentValueError(
        "\"return_result\" has an invalid value: {return_result}".format(
            return_result=return_result))