コード例 #1
0
ファイル: base.py プロジェクト: sdv-dev/SDV
class BaseTimeseriesModel:
    """Base class for timeseries models.

    Args:
        field_names (list[str]):
            List of names of the fields that need to be modeled
            and included in the generated output data. Any additional
            fields found in the data will be ignored and will not be
            included in the generated output.
            If ``None``, all the fields found in the data are used.
        field_types (dict[str, dict]):
            Dictionary specifying the data types and subtypes
            of the fields that will be modeled. Field types and subtypes
            combinations must be compatible with the SDV Metadata Schema.
        anonymize_fields (dict[str, str]):
            Dict specifying which fields to anonymize and what faker
            category they belong to.
        primary_key (str):
            Name of the field which is the primary key of the table.
        entity_columns (list[str]):
            Names of the columns which identify different time series
            sequences. These will be used to group the data in separated
            training examples.
        context_columns (list[str]):
            The columns in the dataframe which are constant within each
            group/entity. These columns will be provided at sampling time
            (i.e. the samples will be conditioned on the context variables).
        segment_size (int, pd.Timedelta or str):
            If specified, cut each training sequence in several segments of
            the indicated size. The size can either be passed as an integer
            value, which will be interpreted as the number of data points to
            put on each segment, or as a pd.Timedelta (or equivalent str
            representation), which will be interpreted as the segment length
            in time. Timedelta segment sizes can only be used with sequence
            indexes of type datetime, and require a sequence index to be
            defined (either through ``sequence_index`` or ``table_metadata``).
        sequence_index (str):
            Name of the column that acts as the order index of each
            sequence. The sequence index column can be of any type that can
            be sorted, such as integer values or datetimes.
        context_model (str, type, tuple or sdv.tabular.BaseTabularModel):
            Model to use to sample the context rows. It can be passed as
            a string, which must be one of the following:

            * `gaussian_copula` (default): Use a GaussianCopula model.

            Alternatively, a Tabular model class, a ``(class, kwargs)``
            tuple, or a preconfigured Tabular model instance can be passed.

        table_metadata (dict or metadata.Table):
            Table metadata instance or dict representation.
            If given alongside any other metadata-related arguments, an
            exception will be raised.
            If not given at all, it will be built using the other
            arguments or learned from the data.

    Raises:
        ValueError:
            If ``table_metadata`` is given together with a non-empty
            metadata-related argument.
        TypeError:
            If a non-integer ``segment_size`` is given and no sequence
            index is defined.
    """

    _DTYPE_TRANSFORMERS = {
        'i': None,
        'f': None,
        'M': rdt.transformers.DatetimeTransformer(strip_constant=True),
        'b': None,
        'O': None,
    }
    _CONTEXT_MODELS = {
        'gaussian_copula': (GaussianCopula, {'categorical_transformer': 'categorical_fuzzy'})
    }

    _metadata = None

    def __init__(self, field_names=None, field_types=None, anonymize_fields=None,
                 primary_key=None, entity_columns=None, context_columns=None,
                 sequence_index=None, segment_size=None, context_model=None,
                 table_metadata=None):
        if table_metadata is None:
            self._metadata = Table(
                field_names=field_names,
                primary_key=primary_key,
                field_types=field_types,
                anonymize_fields=anonymize_fields,
                dtype_transformers=self._DTYPE_TRANSFORMERS,
                sequence_index=sequence_index,
                entity_columns=entity_columns,
                context_columns=context_columns,
            )
            self._metadata_fitted = False
        else:
            # These arguments would be silently ignored when a complete
            # ``table_metadata`` is given, so reject them explicitly. Names are
            # kept next to the values because plain values (lists, strings,
            # dicts) have no ``__name__`` attribute to report.
            metadata_args = (
                ('field_names', field_names),
                ('primary_key', primary_key),
                ('field_types', field_types),
                ('anonymize_fields', anonymize_fields),
                ('sequence_index', sequence_index),
                ('entity_columns', entity_columns),
                ('context_columns', context_columns),
            )
            for name, value in metadata_args:
                if value:
                    raise ValueError(
                        'If table_metadata is given {} must be None'.format(name))

            if isinstance(table_metadata, dict):
                table_metadata = Table.from_dict(
                    table_metadata,
                    dtype_transformers=self._DTYPE_TRANSFORMERS,
                )

            self._metadata = table_metadata
            self._metadata_fitted = table_metadata.fitted

        self._context_columns = self._metadata._context_columns
        self._entity_columns = self._metadata._entity_columns
        self._sequence_index = self._metadata._sequence_index

        # Validate arguments. A non-integer segment size is a time span, so it
        # needs a sequence index. Check the resolved metadata value rather than
        # the ``sequence_index`` argument: when ``table_metadata`` is given that
        # argument is always ``None`` even if the metadata defines an index.
        if segment_size is not None and not isinstance(segment_size, int):
            if self._sequence_index is None:
                raise TypeError(
                    '`segment_size` must be of type `int` if '
                    'no `sequence_index` is given.'
                )

            segment_size = pd.to_timedelta(segment_size)

        self._segment_size = segment_size

        context_model = context_model or 'gaussian_copula'
        if isinstance(context_model, str):
            context_model = self._CONTEXT_MODELS[context_model]

        # A ``(class, kwargs)`` tuple, a model class or a model instance; it is
        # turned into the actual context model in ``_fit_context_model``.
        self._context_model_template = context_model

    def _fit(self, timeseries_data):
        raise NotImplementedError()

    def _fit_context_model(self, transformed):
        """Build the context model from its template and fit it.

        The context table holds one row per entity: the entity columns plus
        the context columns (or a constant dummy column when there are no
        context columns), taking the first row of each entity group.

        Args:
            transformed (pandas.DataFrame):
                Transformed timeseries data which also contains the raw
                entity columns.
        """
        template = self._context_model_template
        default_kwargs = {
            'primary_key': self._entity_columns,
            'field_types': {
                name: meta
                for name, meta in self._metadata.get_fields().items()
                if name in self._entity_columns
            }
        }
        if isinstance(template, tuple):
            context_model_class, context_model_kwargs = copy.deepcopy(template)
            # The defaults are only filled in when the template does not
            # define its own primary key.
            if 'primary_key' not in context_model_kwargs:
                for keyword, argument in default_kwargs.items():
                    context_model_kwargs.setdefault(keyword, argument)

            self._context_model = context_model_class(**context_model_kwargs)
        elif isinstance(template, type):
            self._context_model = template(**default_kwargs)
        else:
            # Preconfigured instance: fit a copy so the template stays untouched.
            self._context_model = copy.deepcopy(template)

        LOGGER.debug('Fitting context model %s', self._context_model.__class__.__name__)
        if self._context_columns:
            context = transformed[self._entity_columns + self._context_columns]
        else:
            context = transformed[self._entity_columns].copy()
            # Add constant column to allow modeling
            context[str(uuid.uuid4())] = 0

        context = context.groupby(self._entity_columns).first().reset_index()
        self._context_model.fit(context)

    def fit(self, timeseries_data):
        """Fit this model to the data.

        If the table metadata has not been fitted yet, learn it from the
        data. The context model is only fitted when entity columns exist.

        Args:
            timeseries_data (pandas.DataFrame):
                pandas.DataFrame containing both the sequences,
                the entity columns and the context columns.
        """
        LOGGER.debug('Fitting %s to table %s; shape: %s', self.__class__.__name__,
                     self._metadata.name, timeseries_data.shape)
        if not self._metadata_fitted:
            self._metadata.fit(timeseries_data)

        LOGGER.debug('Transforming table %s; shape: %s',
                     self._metadata.name, timeseries_data.shape)
        transformed = self._metadata.transform(timeseries_data)

        # Keep the raw entity values instead of their transformed form; they
        # are used as group keys when fitting the context model.
        for column in self._entity_columns:
            transformed[column] = timeseries_data[column]

        if self._entity_columns:
            self._fit_context_model(transformed)

        LOGGER.debug('Fitting %s model to table %s', self.__class__.__name__, self._metadata.name)
        self._fit(transformed)

    def get_metadata(self):
        """Get metadata about the table.

        This will return an ``sdv.metadata.Table`` object containing
        the information about the data that this model has learned.

        This Table metadata will contain some common information,
        such as field names and data types, as well as additional
        information that each Sub-class might add, such as the
        observed data field distributions and their parameters.

        Returns:
            sdv.metadata.Table:
                Table metadata.
        """
        return self._metadata

    def _sample(self, context=None, sequence_length=None):
        raise NotImplementedError()

    def sample(self, num_sequences=None, context=None, sequence_length=None):
        """Sample new sequences.

        Args:
            num_sequences (int):
                Number of sequences to sample. If context is
                passed, this is ignored. If not given, the
                same number of sequences as in the original
                timeseries_data is sampled.
            context (pandas.DataFrame):
                Context values to use when generating the sequences.
                If not passed, the context values will be sampled
                using the specified tabular model.
            sequence_length (int):
                If passed, sample sequences of this length. If not
                given, the sequence length will be sampled from
                the model.

        Returns:
            pandas.DataFrame:
                Table containing the sampled sequences in the same
                format as that of the training data.

        Raises:
            TypeError:
                If ``context`` is given but the model has no entity columns.
        """
        if not self._entity_columns:
            if context is not None:
                raise TypeError('If there are no entity_columns, context must be None')

            context = pd.DataFrame(index=range(num_sequences or 1))
        elif context is None:
            context = self._context_model.sample(num_sequences, output_file_path='disable')
            # Fill any entity column the context model did not produce with
            # sequential ids so every sampled sequence gets a distinct entity.
            for column in self._entity_columns:
                if column not in context:
                    context[column] = range(len(context))

        sampled = self._sample(context, sequence_length)
        return self._metadata.reverse_transform(sampled)

    def save(self, path):
        """Save this model instance to the given path using pickle.

        The installed package versions are recorded on the instance so that
        ``load`` can warn about mismatches.

        Args:
            path (str):
                Path where the SDV instance will be serialized.
        """
        self._package_versions = get_package_versions(getattr(self, '_model', None))

        with open(path, 'wb') as output:
            pickle.dump(self, output)

    @classmethod
    def load(cls, path):
        """Load a timeseries model instance from a given path.

        Warning: this unpickles the file, which can execute arbitrary code;
        only load files from trusted sources.

        Args:
            path (str):
                Path from which to load the instance.

        Returns:
            BaseTimeseriesModel:
                The loaded timeseries model.
        """
        with open(path, 'rb') as f:
            model = pickle.load(f)
            throw_version_mismatch_warning(getattr(model, '_package_versions', None))

            return model
コード例 #2
0
ファイル: base.py プロジェクト: surajitdb/SDV
class BaseTabularModel:
    """Base class for all the tabular models.

    The ``BaseTabularModel`` class defines the common API that all the
    TabularModels need to implement, as well as common functionality.

    Args:
        field_names (list[str]):
            List of names of the fields that need to be modeled
            and included in the generated output data. Any additional
            fields found in the data will be ignored and will not be
            included in the generated output.
            If ``None``, all the fields found in the data are used.
        field_types (dict[str, dict]):
            Dictionary specifying the data types and subtypes
            of the fields that will be modeled. Field types and subtypes
            combinations must be compatible with the SDV Metadata Schema.
        field_transformers (dict[str, str]):
            Dictionary specifying which transformers to use for each field.
            Available transformers are:

                * ``integer``: Uses a ``NumericalTransformer`` of dtype ``int``.
                * ``float``: Uses a ``NumericalTransformer`` of dtype ``float``.
                * ``categorical``: Uses a ``CategoricalTransformer`` without gaussian noise.
                * ``categorical_fuzzy``: Uses a ``CategoricalTransformer`` adding gaussian noise.
                * ``one_hot_encoding``: Uses a ``OneHotEncodingTransformer``.
                * ``label_encoding``: Uses a ``LabelEncodingTransformer``.
                * ``boolean``: Uses a ``BooleanTransformer``.
                * ``datetime``: Uses a ``DatetimeTransformer``.

        anonymize_fields (dict[str, str]):
            Dict specifying which fields to anonymize and what faker
            category they belong to.
        primary_key (str):
            Name of the field which is the primary key of the table.
        constraints (list[Constraint, dict]):
            List of Constraint objects or dicts.
        table_metadata (dict or metadata.Table):
            Table metadata instance or dict representation.
            If given alongside any other metadata-related arguments, an
            exception will be raised.
            If not given at all, it will be built using the other
            arguments or learned from the data.

    Raises:
        ValueError:
            If ``table_metadata`` is given together with a non-empty
            metadata-related argument.
    """

    # Subclasses may define a dtype-kind -> transformer mapping; ``None`` here.
    _DTYPE_TRANSFORMERS = None

    _metadata = None

    def __init__(self,
                 field_names=None,
                 field_types=None,
                 field_transformers=None,
                 anonymize_fields=None,
                 primary_key=None,
                 constraints=None,
                 table_metadata=None):
        if table_metadata is None:
            self._metadata = Table(
                field_names=field_names,
                primary_key=primary_key,
                field_types=field_types,
                field_transformers=field_transformers,
                anonymize_fields=anonymize_fields,
                constraints=constraints,
                dtype_transformers=self._DTYPE_TRANSFORMERS,
            )
            self._metadata_fitted = False
        else:
            # These arguments would be silently ignored when a complete
            # ``table_metadata`` is given, so reject them explicitly. Names are
            # kept next to the values because plain values (lists, strings,
            # dicts) have no ``__name__`` attribute to report.
            metadata_args = (
                ('field_names', field_names),
                ('primary_key', primary_key),
                ('field_types', field_types),
                ('field_transformers', field_transformers),
                ('anonymize_fields', anonymize_fields),
                ('constraints', constraints),
            )
            for name, value in metadata_args:
                if value:
                    raise ValueError(
                        'If table_metadata is given {} must be None'.format(name))

            if isinstance(table_metadata, dict):
                table_metadata = Table.from_dict(table_metadata)

            # ``dict.update(None)`` raises, and the base class leaves
            # ``_DTYPE_TRANSFORMERS`` as ``None``, so only merge when defined.
            if self._DTYPE_TRANSFORMERS is not None:
                table_metadata._dtype_transformers.update(self._DTYPE_TRANSFORMERS)

            self._metadata = table_metadata
            self._metadata_fitted = table_metadata.fitted

    def fit(self, data):
        """Fit this model to the data.

        If the table metadata has not been given, learn it from the data.
        The underlying model is only fitted when the metadata reports
        non-id columns.

        Args:
            data (pandas.DataFrame or str):
                Data to fit the model to. It can be passed as a
                ``pandas.DataFrame`` or as an ``str``.
                If an ``str`` is passed, it is assumed to be
                the path to a CSV file which can be loaded using
                ``pandas.read_csv``.
        """
        if isinstance(data, str):
            data = pd.read_csv(data)

        LOGGER.debug('Fitting %s to table %s; shape: %s',
                     self.__class__.__name__, self._metadata.name, data.shape)
        if not self._metadata_fitted:
            self._metadata.fit(data)

        self._num_rows = len(data)

        LOGGER.debug('Transforming table %s; shape: %s', self._metadata.name,
                     data.shape)
        transformed = self._metadata.transform(data)

        if self._metadata.get_dtypes(ids=False):
            LOGGER.debug('Fitting %s model to table %s',
                         self.__class__.__name__, self._metadata.name)
            self._fit(transformed)

    def get_metadata(self):
        """Get metadata about the table.

        This will return an ``sdv.metadata.Table`` object containing
        the information about the data that this model has learned.

        This Table metadata will contain some common information,
        such as field names and data types, as well as additional
        information that each Sub-class might add, such as the
        observed data field distributions and their parameters.

        Returns:
            sdv.metadata.Table:
                Table metadata.
        """
        return self._metadata

    @staticmethod
    def _filter_conditions(sampled, conditions, float_rtol):
        """Filter the sampled rows that match the conditions.

        If condition columns are float values, consider a match anything whose
        absolute distance to the target is at most ``abs(value) * float_rtol``
        and then make the value exact. A target of ``0.0`` therefore only
        matches exactly.

        Args:
            sampled (pandas.DataFrame):
                The sampled rows, reverse transformed.
            conditions (dict or pandas.Series):
                Mapping of column name to conditioning value.
            float_rtol (float):
                Maximum relative tolerance when considering a float match.

        Returns:
            pandas.DataFrame:
                Rows from the sampled data that match the conditions.
        """
        for column, value in conditions.items():
            column_values = sampled[column]
            if column_values.dtype.kind == 'f':
                # Use the magnitude so negative targets get a positive window.
                distance = abs(value) * float_rtol
                matches = np.abs(column_values - value) <= distance
                # Copy so the assignment below does not write into a slice.
                sampled = sampled[matches].copy()
                sampled[column] = value
            else:
                sampled = sampled[column_values == value]

        return sampled

    def _sample_rows(self,
                     num_rows,
                     conditions=None,
                     transformed_conditions=None,
                     float_rtol=0.1,
                     previous_rows=None):
        """Sample rows with the given conditions.

        Input conditions is taken both in the raw input format, which will be used
        for filtering during the reject-sampling loop, and already transformed
        to the model format, which will be passed down to the model if it supports
        conditional sampling natively.

        If condition columns are float values, consider a match anything that
        is closer than the given ``float_rtol`` and then make the value exact.

        If the model does not have any data columns, the result of this call
        is a dataframe of the requested length with no columns in it.

        Args:
            num_rows (int):
                Number of rows to sample.
            conditions (dict):
                The dictionary of conditioning values in the original format.
            transformed_conditions (dict):
                The dictionary of conditioning values transformed to the model format.
            float_rtol (float):
                Maximum tolerance when considering a float match.
            previous_rows (pandas.DataFrame):
                Valid rows sampled in the previous iterations.

        Returns:
            tuple:
                * pandas.DataFrame:
                    Rows from the sampled data that match the conditions.
                * int:
                    Number of rows that are considered valid.
        """
        if self._metadata.get_dtypes(ids=False):
            if conditions is None:
                sampled = self._sample(num_rows)
            else:
                try:
                    sampled = self._sample(num_rows, transformed_conditions)
                except NotImplementedError:
                    # The model cannot condition natively; fall back to
                    # unconditional sampling plus reject-sampling below.
                    sampled = self._sample(num_rows)

            sampled = self._metadata.reverse_transform(sampled)

            if previous_rows is not None:
                # ``DataFrame.append`` was removed in pandas 2.0.
                sampled = pd.concat([previous_rows, sampled])

            sampled = self._metadata.filter_valid(sampled)

            if conditions is not None:
                sampled = self._filter_conditions(sampled, conditions,
                                                  float_rtol)

            num_valid = len(sampled)

            return sampled, num_valid

        else:
            sampled = pd.DataFrame(index=range(num_rows))
            sampled = self._metadata.reverse_transform(sampled)
            return sampled, num_rows

    def _sample_batch(self,
                      num_rows=None,
                      max_retries=100,
                      max_rows_multiplier=10,
                      conditions=None,
                      transformed_conditions=None,
                      float_rtol=0.01):
        """Sample a batch of rows with the given conditions.

        This will enter a reject-sampling loop in which rows will be sampled until
        all of them are valid and match the requested conditions. If `max_retries`
        is exceeded, it will return as many rows as it has sampled, which may be less
        than the target number of rows.

        Input conditions is taken both in the raw input format, which will be used
        for filtering during the reject-sampling loop, and already transformed
        to the model format, which will be passed down to the model if it supports
        conditional sampling natively.

        If condition columns are float values, consider a match anything that is
        relatively closer than the given ``float_rtol`` and then make the value exact.

        If the model does not have any data columns, the result of this call
        is a dataframe of the requested length with no columns in it.

        Args:
            num_rows (int):
                Number of rows to sample. If not given the model
                will generate as many rows as there were in the
                data passed to the ``fit`` method.
            max_retries (int):
                Number of times to retry sampling discarded rows.
                Defaults to 100.
            max_rows_multiplier (int):
                Multiplier to use when computing the maximum number of rows
                that can be sampled during the reject-sampling loop.
                The maximum number of rows that are sampled at each iteration
                will be equal to this number multiplied by the requested num_rows.
                Defaults to 10.
            conditions (dict):
                The dictionary of conditioning values in the original input format.
            transformed_conditions (dict):
                The dictionary of conditioning values transformed to the model format.
            float_rtol (float):
                Maximum tolerance when considering a float match.

        Returns:
            pandas.DataFrame:
                Sampled data.
        """
        if num_rows is None:
            num_rows = self._num_rows

        sampled, num_valid = self._sample_rows(num_rows, conditions,
                                               transformed_conditions,
                                               float_rtol)

        counter = 0
        total_sampled = num_rows
        while num_valid < num_rows:
            if counter >= max_retries:
                break

            remaining = num_rows - num_valid
            # Estimate the acceptance rate so far (add-one smoothed) and
            # oversample accordingly, capped at ``max_rows`` per iteration.
            valid_probability = (num_valid + 1) / (total_sampled + 1)
            max_rows = num_rows * max_rows_multiplier
            num_to_sample = min(int(remaining / valid_probability), max_rows)
            total_sampled += num_to_sample

            LOGGER.info('%s valid rows remaining. Resampling %s rows',
                        remaining, num_to_sample)
            sampled, num_valid = self._sample_rows(num_to_sample, conditions,
                                                   transformed_conditions,
                                                   float_rtol, sampled)

            counter += 1

        return sampled.head(num_rows)

    def _make_conditions_df(self, conditions, num_rows):
        """Transform `conditions` into a dataframe.

        Args:
            conditions (pd.DataFrame, dict or pd.Series):
                If this is a dictionary/Series which maps column names to the column
                value, then this method generates `num_rows` samples, all of
                which are conditioned on the given variables. If this is a DataFrame,
                then it generates an output DataFrame such that each row in the output
                is sampled conditional on the corresponding row in the input.
                A dict whose values are all scalars cannot be turned into a
                DataFrame directly, so it is repeated ``num_rows`` times instead.
            num_rows (int):
                Number of rows to sample. If a conditions dataframe is given, this must
                either be ``None`` or match the length of the ``conditions`` dataframe.
                For a Series or a dict of scalars it must be given.

        Returns:
            pandas.DataFrame:
                `conditions` as a dataframe.
        """
        if isinstance(conditions, pd.Series):
            conditions = pd.DataFrame([conditions] * num_rows)

        elif isinstance(conditions, dict):
            try:
                conditions = pd.DataFrame(conditions)
            except ValueError:
                conditions = pd.DataFrame([conditions] * num_rows)

        elif not isinstance(conditions, pd.DataFrame):
            raise TypeError(
                '`conditions` must be a dataframe, a dictionary or a pandas series.'
            )

        elif num_rows is not None and len(conditions) != num_rows:
            raise ValueError(
                'If `conditions` is a `DataFrame`, `num_rows` must be `None` or match its lenght.'
            )

        return conditions.copy()

    def sample(self,
               num_rows=None,
               max_retries=100,
               max_rows_multiplier=10,
               conditions=None,
               float_rtol=0.01,
               graceful_reject_sampling=False):
        """Sample rows from this table.

        Args:
            num_rows (int):
                Number of rows to sample. If not given the model
                will generate as many rows as there were in the
                data passed to the ``fit`` method.
            max_retries (int):
                Number of times to retry sampling discarded rows.
                Defaults to 100.
            max_rows_multiplier (int):
                Multiplier to use when computing the maximum number of rows
                that can be sampled during the reject-sampling loop.
                The maximum number of rows that are sampled at each iteration
                will be equal to this number multiplied by the requested num_rows.
                Defaults to 10.
            conditions (pd.DataFrame, dict or pd.Series):
                If this is a dictionary/Series which maps column names to the column
                value, then this method generates `num_rows` samples, all of
                which are conditioned on the given variables. If this is a DataFrame,
                then it generates an output DataFrame such that each row in the output
                is sampled conditional on the corresponding row in the input.
            float_rtol (float):
                Maximum tolerance when considering a float match. This is the maximum
                relative distance at which a float value will be considered a match
                when performing reject-sampling based conditioning. Defaults to 0.01.
            graceful_reject_sampling (bool):
                If `False` raises a `ValueError` if not enough valid rows could be sampled
                within `max_retries` trials. If `True` prints a warning and returns
                as many rows as it was able to sample within `max_retries`. If no rows could
                be generated, raises a `ValueError`.
                Defaults to False.

        Returns:
            pandas.DataFrame:
                Sampled data.

        Raises:
            ValueError:
                If a condition column is unknown or cannot be conditioned on,
                or if not enough valid rows could be sampled (see
                ``graceful_reject_sampling``).
        """
        if conditions is None:
            num_rows = num_rows or self._num_rows
            return self._sample_batch(num_rows, max_retries,
                                      max_rows_multiplier)

        # convert conditions to dataframe
        conditions = self._make_conditions_df(conditions, num_rows)

        # validate columns
        for column in conditions.columns:
            if column not in self._metadata.get_fields():
                raise ValueError(f'Invalid column name `{column}`')

            # A column that transforms into nothing cannot be conditioned on.
            if len(self._metadata.transform(conditions[[column
                                                        ]]).columns) == 0:
                raise ValueError(
                    f'Conditioning on column `{column}` is not possible')

        transformed_conditions = self._metadata.transform(conditions)
        condition_columns = list(transformed_conditions.columns)
        # Remember the original row of every condition so the output can be
        # put back in the input order after grouping identical conditions.
        transformed_conditions.index.name = '__condition_idx__'
        transformed_conditions.reset_index(inplace=True)
        grouped_conditions = transformed_conditions.groupby(condition_columns)

        # sample
        all_sampled_rows = list()

        for index, dataframe in grouped_conditions:
            if not isinstance(index, tuple):
                index = [index]

            condition = conditions.loc[dataframe['__condition_idx__'].iloc[0]]
            transformed_condition = dict(zip(condition_columns, index))
            sampled_rows = self._sample_batch(len(dataframe), max_retries,
                                              max_rows_multiplier, condition,
                                              transformed_condition,
                                              float_rtol)

            if len(sampled_rows) < len(dataframe):
                # Didn't get enough rows.
                if len(sampled_rows) == 0:
                    error = 'No valid rows could be generated with the given conditions.'
                    raise ValueError(error)

                elif not graceful_reject_sampling:
                    error = f'Could not get enough valid rows within {max_retries} trials.'
                    raise ValueError(error)

                else:
                    warn(f'Only {len(sampled_rows)} rows could '
                         f'be sampled within {max_retries} trials.')

            if len(sampled_rows) > 0:
                sampled_rows['__condition_idx__'] = \
                    dataframe['__condition_idx__'].values[:len(sampled_rows)]
                all_sampled_rows.append(sampled_rows)

        all_sampled_rows = pd.concat(all_sampled_rows)
        all_sampled_rows = all_sampled_rows.set_index('__condition_idx__')
        all_sampled_rows.index.name = conditions.index.name
        all_sampled_rows = all_sampled_rows.sort_index()

        return all_sampled_rows

    def _get_parameters(self):
        raise NonParametricError()

    def get_parameters(self):
        """Get the parameters learned from the data.

        The result is a flat dict (single level) which contains
        all the necessary parameters to be able to reproduce
        this model.

        Subclasses which are not parametric, such as DeepLearning
        based models, raise a NonParametricError indicating that
        this method is not supported for their implementation.

        Returns:
            parameters (dict):
                flat dict (single level) which contains all the
                necessary parameters to be able to reproduce
                this model.

        Raises:
            NonParametricError:
                If the model is not parametric or cannot be described
                using a simple dictionary.
        """
        if self._metadata.get_dtypes(ids=False):
            parameters = self._get_parameters()
        else:
            parameters = {}

        parameters['num_rows'] = self._num_rows
        return parameters

    def _set_parameters(self, parameters):
        raise NonParametricError()

    def set_parameters(self, parameters):
        """Regenerate a previously learned model from its parameters.

        Subclasses which are not parametric, such as DeepLearning
        based models, raise a NonParametricError indicating that
        this method is not supported for their implementation.

        The given dict is not modified.

        Args:
            dict:
                Model parameters. Must contain a ``num_rows`` entry, which is
                rounded to a non-negative int (null becomes 0).

        Raises:
            NonParametricError:
                If the model is not parametric or cannot be described
                using a simple dictionary.
        """
        # Work on a copy so popping ``num_rows`` does not mutate the caller's dict.
        parameters = parameters.copy()
        num_rows = parameters.pop('num_rows')
        self._num_rows = 0 if pd.isnull(num_rows) else max(
            0, int(round(num_rows)))

        if self._metadata.get_dtypes(ids=False):
            self._set_parameters(parameters)

    def save(self, path):
        """Save this model instance to the given path using pickle.

        Args:
            path (str):
                Path where the SDV instance will be serialized.
        """
        with open(path, 'wb') as output:
            pickle.dump(self, output)

    @classmethod
    def load(cls, path):
        """Load a TabularModel instance from a given path.

        Warning: this unpickles the file, which can execute arbitrary code;
        only load files from trusted sources.

        Args:
            path (str):
                Path from which to load the instance.

        Returns:
            TabularModel:
                The loaded tabular model.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)
コード例 #3
0
ファイル: base.py プロジェクト: sdv-dev/SDV
class BaseTabularModel:
    """Base class for all the tabular models.

    The ``BaseTabularModel`` class defines the common API that all the
    TabularModels need to implement, as well as common functionality.

    Args:
        field_names (list[str]):
            List of names of the fields that need to be modeled
            and included in the generated output data. Any additional
            fields found in the data will be ignored and will not be
            included in the generated output.
            If ``None``, all the fields found in the data are used.
        field_types (dict[str, dict]):
            Dictinary specifying the data types and subtypes
            of the fields that will be modeled. Field types and subtypes
            combinations must be compatible with the SDV Metadata Schema.
        field_transformers (dict[str, str]):
            Dictinary specifying which transformers to use for each field.
            Available transformers are:

                * ``integer``: Uses a ``NumericalTransformer`` of dtype ``int``.
                * ``float``: Uses a ``NumericalTransformer`` of dtype ``float``.
                * ``categorical``: Uses a ``CategoricalTransformer`` without gaussian noise.
                * ``categorical_fuzzy``: Uses a ``CategoricalTransformer`` adding gaussian noise.
                * ``one_hot_encoding``: Uses a ``OneHotEncodingTransformer``.
                * ``label_encoding``: Uses a ``LabelEncodingTransformer``.
                * ``boolean``: Uses a ``BooleanTransformer``.
                * ``datetime``: Uses a ``DatetimeTransformer``.

        anonymize_fields (dict[str, str]):
            Dict specifying which fields to anonymize and what faker
            category they belong to.
        primary_key (str):
            Name of the field which is the primary key of the table.
        constraints (list[Constraint, dict]):
            List of Constraint objects or dicts.
        table_metadata (dict or metadata.Table):
            Table metadata instance or dict representation.
            If given alongside any other metadata-related arguments, an
            exception will be raised.
            If not given at all, it will be built using the other
            arguments or learned from the data.
        rounding (int, str or None):
            Define rounding scheme for ``NumericalTransformer``. If set to an int, values
            will be rounded to that number of decimal places. If ``None``, values will not
            be rounded. If set to ``'auto'``, the transformer will round to the maximum number
            of decimal places detected in the fitted data. Defaults to ``'auto'``.
        min_value (int, str or None):
            Specify the minimum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be greater than or equal to it. If the string ``'auto'``
            is given, the minimum will be the minimum value seen in the fitted data. If ``None``
            is given, there won't be a minimum. Defaults to ``'auto'``.
        max_value (int, str or None):
            Specify the maximum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be less than or equal to it. If the string ``'auto'``
            is given, the maximum will be the maximum value seen in the fitted data. If ``None``
            is given, there won't be a maximum. Defaults to ``'auto'``.
    """

    _DTYPE_TRANSFORMERS = None

    _metadata = None

    def __init__(self,
                 field_names=None,
                 field_types=None,
                 field_transformers=None,
                 anonymize_fields=None,
                 primary_key=None,
                 constraints=None,
                 table_metadata=None,
                 rounding='auto',
                 min_value='auto',
                 max_value='auto'):
        if table_metadata is None:
            self._metadata = Table(field_names=field_names,
                                   primary_key=primary_key,
                                   field_types=field_types,
                                   field_transformers=field_transformers,
                                   anonymize_fields=anonymize_fields,
                                   constraints=constraints,
                                   dtype_transformers=self._DTYPE_TRANSFORMERS,
                                   rounding=rounding,
                                   min_value=min_value,
                                   max_value=max_value)
            self._metadata_fitted = False
        else:
            table_metadata = deepcopy(table_metadata)
            for arg in (field_names, primary_key, field_types,
                        anonymize_fields, constraints):
                if arg:
                    raise ValueError(
                        'If table_metadata is given {} must be None'.format(
                            arg.__name__))

            if isinstance(table_metadata, dict):
                table_metadata = Table.from_dict(table_metadata)

            table_metadata._dtype_transformers.update(self._DTYPE_TRANSFORMERS)

            self._metadata = table_metadata
            self._metadata_fitted = table_metadata.fitted

    def fit(self, data):
        """Fit this model to the data.

        If the table metadata has not been given, learn it from the data.

        Args:
            data (pandas.DataFrame or str):
                Data to fit the model to. It can be passed as a
                ``pandas.DataFrame`` or as an ``str`` (or path-like object).
                If an ``str`` is passed, it is assumed to be
                the path to a CSV file which can be loaded using
                ``pandas.read_csv``.
        """
        if isinstance(data, (str, os.PathLike)):
            # Honour the documented contract: a string is a CSV file path.
            data = pd.read_csv(data)
        elif isinstance(data, pd.DataFrame):
            data = data.reset_index(drop=True)

        LOGGER.debug('Fitting %s to table %s; shape: %s',
                     self.__class__.__name__, self._metadata.name, data.shape)
        if not self._metadata_fitted:
            self._metadata.fit(data)

        self._num_rows = len(data)

        LOGGER.debug('Transforming table %s; shape: %s', self._metadata.name,
                     data.shape)
        transformed = self._metadata.transform(data)

        # Only run the model-specific fit when there are non-id columns to model.
        if self._metadata.get_dtypes(ids=False):
            LOGGER.debug('Fitting %s model to table %s',
                         self.__class__.__name__, self._metadata.name)
            self._fit(transformed)

    def get_metadata(self):
        """Get metadata about the table.

        This will return an ``sdv.metadata.Table`` object containing
        the information about the data that this model has learned.

        This Table metadata will contain some common information,
        such as field names and data types, as well as additional
        information that each Sub-class might add, such as the
        observed data field distributions and their parameters.

        Returns:
            sdv.metadata.Table:
                Table metadata.
        """
        return self._metadata

    @staticmethod
    def _filter_conditions(sampled, conditions, float_rtol):
        """Filter the sampled rows that match the conditions.

        If condition columns are float values, consider a match anything that
        is closer than the given ``float_rtol`` and then make the value exact.

        Args:
            sampled (pandas.DataFrame):
                The sampled rows, reverse transformed.
            conditions (dict):
                The dictionary of conditioning values.
            float_rtol (float):
                Maximum tolerance when considering a float match.

        Returns:
            pandas.DataFrame:
                Rows from the sampled data that match the conditions.
        """
        for column, value in conditions.items():
            column_values = sampled[column]
            if column_values.dtype.kind == 'f':
                distance = value * float_rtol
                sampled = sampled[np.abs(column_values - value) <= distance]
                sampled[column] = value
            else:
                sampled = sampled[column_values == value]

        return sampled

    def _sample_rows(self,
                     num_rows,
                     conditions=None,
                     transformed_conditions=None,
                     float_rtol=0.1,
                     previous_rows=None):
        """Sample rows with the given conditions.

        Input conditions is taken both in the raw input format, which will be used
        for filtering during the reject-sampling loop, and already transformed
        to the model format, which will be passed down to the model if it supports
        conditional sampling natively.

        If condition columns are float values, consider a match anything that
        is closer than the given ``float_rtol`` and then make the value exact.

        If the model does not have any data columns, the result of this call
        is a dataframe of the requested length with no columns in it.

        Args:
            num_rows (int):
                Number of rows to sample.
            conditions (dict):
                The dictionary of conditioning values in the original format.
            transformed_conditions (dict):
                The dictionary of conditioning values transformed to the model format.
            float_rtol (float):
                Maximum tolerance when considering a float match.
            previous_rows (pandas.DataFrame):
                Valid rows sampled in the previous iterations.

        Returns:
            tuple:
                * pandas.DataFrame:
                    Rows from the sampled data that match the conditions.
                * int:
                    Number of rows that are considered valid.
        """
        if self._metadata.get_dtypes(ids=False):
            if conditions is None:
                sampled = self._sample(num_rows)
            else:
                try:
                    sampled = self._sample(num_rows, transformed_conditions)
                except NotImplementedError:
                    sampled = self._sample(num_rows)

            sampled = self._metadata.reverse_transform(sampled)

            if previous_rows is not None:
                sampled = pd.concat([previous_rows, sampled],
                                    ignore_index=True)

            sampled = self._metadata.filter_valid(sampled)

            if conditions is not None:
                sampled = self._filter_conditions(sampled, conditions,
                                                  float_rtol)

            num_valid = len(sampled)

            return sampled, num_valid

        else:
            sampled = pd.DataFrame(index=range(num_rows))
            sampled = self._metadata.reverse_transform(sampled)
            return sampled, num_rows

    def _sample_batch(self,
                      num_rows=None,
                      max_tries=100,
                      batch_size_per_try=None,
                      conditions=None,
                      transformed_conditions=None,
                      float_rtol=0.01,
                      progress_bar=None,
                      output_file_path=None):
        """Sample a batch of rows with the given conditions.

        This will enter a reject-sampling loop in which rows will be sampled until
        all of them are valid and match the requested conditions. If `max_tries`
        is exceeded, it will return as many rows as it has sampled, which may be less
        than the target number of rows.

        Input conditions is taken both in the raw input format, which will be used
        for filtering during the reject-sampling loop, and already transformed
        to the model format, which will be passed down to the model if it supports
        conditional sampling natively.

        If condition columns are float values, consider a match anything that is
        relatively closer than the given ``float_rtol`` and then make the value exact.

        If the model does not have any data columns, the result of this call
        is a dataframe of the requested length with no columns in it.

        Args:
            num_rows (int):
                Number of rows to sample. If not given the model
                will generate as many rows as there were in the
                data passed to the ``fit`` method.
            max_tries (int):
                Number of times to try sampling discarded rows.
                Defaults to 100.
            batch_size_per_try (int):
                The batch size to use per attempt at sampling. Defaults to 10 times
                the number of rows.
            conditions (dict):
                The dictionary of conditioning values in the original input format.
            transformed_conditions (dict):
                The dictionary of conditioning values transformed to the model format.
            float_rtol (float):
                Maximum tolerance when considering a float match.
            progress_bar (tqdm.tqdm or None):
                The progress bar to update when sampling. If None, a new tqdm progress
                bar will be created.
            output_file_path (str or None):
                The file to periodically write sampled rows to. If None, does not write
                rows anywhere.

        Returns:
            pandas.DataFrame:
                Sampled data.
        """
        if not batch_size_per_try:
            batch_size_per_try = num_rows * 10

        counter = 0
        num_valid = 0
        prev_num_valid = None
        remaining = num_rows
        sampled = pd.DataFrame()

        while num_valid < num_rows:
            if counter >= max_tries:
                break

            prev_num_valid = num_valid
            sampled, num_valid = self._sample_rows(
                batch_size_per_try,
                conditions,
                transformed_conditions,
                float_rtol,
                sampled,
            )

            num_increase = min(num_valid - prev_num_valid, remaining)
            if num_increase > 0:
                if output_file_path:
                    append_kwargs = {
                        'mode': 'a',
                        'header': False
                    } if os.path.getsize(output_file_path) > 0 else {}
                    sampled.head(min(len(sampled),
                                     num_rows)).tail(num_increase).to_csv(
                                         output_file_path,
                                         index=False,
                                         **append_kwargs,
                                     )
                if progress_bar:
                    progress_bar.update(num_increase)

            remaining = num_rows - num_valid
            if remaining > 0:
                LOGGER.info(
                    f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows'
                )

            counter += 1

        return sampled.head(min(len(sampled), num_rows))

    def _make_condition_dfs(self, conditions):
        """Transform `conditions` into a list of dataframes.

        Args:
            conditions (list[sdv.sampling.Condition]):
                A list of `sdv.sampling.Condition`, where each `Condition` object
                represents a desired column value mapping and the number of rows
                to generate for that condition.

        Returns:
            list[pandas.DataFrame]:
                A list of `conditions` as dataframes.
        """
        condition_dataframes = defaultdict(list)
        for condition in conditions:
            column_values = condition.get_column_values()
            condition_dataframes[tuple(column_values.keys())].append(
                pd.DataFrame(column_values,
                             index=range(condition.get_num_rows())))

        return [
            pd.concat(condition_list, ignore_index=True)
            for condition_list in condition_dataframes.values()
        ]

    def _conditionally_sample_rows(self,
                                   dataframe,
                                   condition,
                                   transformed_condition,
                                   max_tries=None,
                                   batch_size_per_try=None,
                                   float_rtol=0.01,
                                   graceful_reject_sampling=True,
                                   progress_bar=None,
                                   output_file_path=None):
        num_rows = len(dataframe)
        sampled_rows = self._sample_batch(
            num_rows,
            max_tries,
            batch_size_per_try,
            condition,
            transformed_condition,
            float_rtol,
            progress_bar,
            output_file_path,
        )

        if len(sampled_rows) > 0:
            sampled_rows[COND_IDX] = dataframe[
                COND_IDX].values[:len(sampled_rows)]

        else:
            # Didn't get any rows.
            if not graceful_reject_sampling:
                user_msg = (
                    'Unable to sample any rows for the given conditions '
                    f'`{transformed_condition}`. ')
                if hasattr(self, '_model') and isinstance(
                        self._model,
                        copulas.multivariate.GaussianMultivariate):
                    user_msg = user_msg + (
                        'This may be because the provided values are out-of-bounds in the '
                        'current model. \nPlease try again with a different set of values.'
                    )
                else:
                    user_msg = user_msg + (
                        f'Try increasing `max_tries` (currently: {max_tries}) or increasing '
                        f'`batch_size_per_try` (currently: {batch_size_per_try}). Note that '
                        'increasing these values will also increase the sampling time.'
                    )

                raise ValueError(user_msg)

        return sampled_rows

    def _validate_file_path(self, output_file_path):
        """Resolve the sampling output path and create an empty file there.

        Args:
            output_file_path (str or None):
                ``DISABLE_TMP_FILE`` disables the output file (returns None).
                A user path must not exist yet. A falsy value selects
                ``TMP_FILE_NAME``, removing any stale copy first.

        Returns:
            str or None:
                The path of the freshly created file, or None if disabled.

        Raises:
            AssertionError:
                If a user-given path already exists.
        """
        if output_file_path == DISABLE_TMP_FILE:
            # Temporary way of disabling the output file feature, used by HMA1.
            return None

        if not output_file_path:
            if os.path.exists(TMP_FILE_NAME):
                os.remove(TMP_FILE_NAME)

            resolved_path = TMP_FILE_NAME
        else:
            resolved_path = os.path.abspath(output_file_path)
            if os.path.exists(resolved_path):
                raise AssertionError(f'{resolved_path} already exists.')

        # Create the file so that later size checks and appends have a target.
        with open(resolved_path, 'w+'):
            pass

        return resolved_path

    def _randomize_samples(self, randomize_samples):
        """Randomize the samples according to user input.

        If ``randomize_samples`` is false, fix the seed that the random number generator
        uses in the underlying models.

        Args:
            randomize_samples (bool):
                Whether or not to randomize the generated samples.
        """
        if self._model is None:
            return

        if randomize_samples:
            self._set_random_state(None)
        else:
            self._set_random_state(FIXED_RNG_SEED)

    def sample(self,
               num_rows,
               randomize_samples=True,
               batch_size=None,
               output_file_path=None,
               conditions=None):
        """Sample rows from this table.

        Args:
            num_rows (int):
                Number of rows to sample. This parameter is required.
            randomize_samples (bool):
                Whether or not to use a fixed seed when sampling. Defaults
                to True.
            batch_size (int or None):
                The batch size to sample. Defaults to `num_rows`, if None.
            output_file_path (str or None):
                The file to periodically write sampled rows to. If None, does not
                write rows anywhere.
            conditions:
                Deprecated argument. Use the `sample_conditions` method with
                `sdv.sampling.Condition` objects instead.

        Returns:
            pandas.DataFrame:
                Sampled data.
        """
        if conditions is not None:
            raise TypeError(
                'This method does not support the conditions parameter. '
                'Please create `sdv.sampling.Condition` objects and pass them '
                'into the `sample_conditions` method. '
                'See User Guide or API for more details.')

        if num_rows is None:
            raise ValueError(
                'You must specify the number of rows to sample (e.g. num_rows=100).'
            )

        if num_rows == 0:
            return pd.DataFrame()

        self._randomize_samples(randomize_samples)

        output_file_path = self._validate_file_path(output_file_path)

        batch_size = min(batch_size, num_rows) if batch_size else num_rows

        sampled = []
        try:

            def _sample_function(progress_bar=None):
                for step in range(math.ceil(num_rows / batch_size)):
                    # The last batch asks only for the rows still missing, so the
                    # total requested (and the progress bar) equals ``num_rows``.
                    rows_in_batch = min(batch_size, num_rows - step * batch_size)
                    sampled_rows = self._sample_batch(
                        rows_in_batch,
                        batch_size_per_try=batch_size,
                        progress_bar=progress_bar,
                        output_file_path=output_file_path,
                    )
                    sampled.append(sampled_rows)

                return sampled

            if batch_size == num_rows:
                sampled = _sample_function()
            else:
                sampled = progress_bar_wrapper(_sample_function, num_rows,
                                               'Sampling rows')

        except (Exception, KeyboardInterrupt) as error:
            handle_sampling_error(output_file_path == TMP_FILE_NAME,
                                  output_file_path, error)

        else:
            # Success: the temporary file is only kept around for recovery.
            if output_file_path == TMP_FILE_NAME and os.path.exists(
                    output_file_path):
                os.remove(output_file_path)

        return pd.concat(
            sampled,
            ignore_index=True) if len(sampled) > 0 else pd.DataFrame()

    def _validate_conditions(self, conditions):
        """Validate the user-passed conditions."""
        for column in conditions.columns:
            if column not in self._metadata.get_fields():
                raise ValueError(
                    f'Unexpected column name `{column}`. '
                    f'Use a column name that was present in the original data.'
                )

    def _sample_with_conditions(self,
                                conditions,
                                max_tries,
                                batch_size_per_try,
                                progress_bar=None,
                                output_file_path=None):
        """Sample rows with conditions.

        Note: ``conditions`` is modified in place (its index is moved into a
        ``COND_IDX`` column).

        Args:
            conditions (pandas.DataFrame):
                A DataFrame representing the conditions to be sampled.
            max_tries (int):
                Number of times to try sampling discarded rows. Defaults to 100.
            batch_size_per_try (int):
                The batch size to use per attempt at sampling. Defaults to 10 times
                the number of rows.
            progress_bar (tqdm.tqdm or None):
                The progress bar to update.
            output_file_path (str or None):
                The file to periodically write sampled rows to. Defaults to
                a temporary file, if None.

        Returns:
            pandas.DataFrame:
                Sampled data, indexed like ``conditions`` (including its index
                name). Empty if ``conditions`` has no rows.

        Raises:
            ConstraintsNotMetError:
                If the conditions are not valid for the given constraints.
            ValueError:
                If any of the following happens:
                    * any of the conditions' columns are not valid.
                    * no rows could be generated.
        """
        try:
            transformed_conditions = self._metadata.transform(
                conditions, is_condition=True)
        except ConstraintsNotMetError as cnme:
            cnme.message = 'Provided conditions are not valid for the given constraints'
            raise

        # Capture the caller's index name now; it is overwritten just below.
        original_index_name = conditions.index.name

        condition_columns = list(conditions.columns)
        transformed_columns = list(transformed_conditions.columns)
        conditions.index.name = COND_IDX
        conditions.reset_index(inplace=True)
        transformed_conditions.index.name = COND_IDX
        transformed_conditions.reset_index(inplace=True)
        grouped_conditions = conditions.groupby(condition_columns)

        all_sampled_rows = []
        for group, dataframe in grouped_conditions:
            if not isinstance(group, tuple):
                group = [group]

            condition = dict(zip(condition_columns, group))
            if not transformed_columns:
                sampled_rows = self._conditionally_sample_rows(
                    dataframe,
                    condition,
                    None,
                    max_tries,
                    batch_size_per_try,
                    progress_bar=progress_bar,
                    output_file_path=output_file_path,
                )
                all_sampled_rows.append(sampled_rows)
                continue

            # After ``reset_index`` both frames carry a positional RangeIndex, so the
            # rows of this group are picked by position (``dataframe.index``). The
            # original labels stored in ``COND_IDX`` only coincide with positions
            # for a default index. Assumes ``transform`` keeps row order and count.
            transformed_conditions_in_group = transformed_conditions.loc[dataframe.index]
            transformed_groups = transformed_conditions_in_group.groupby(
                transformed_columns)
            for transformed_group, transformed_dataframe in transformed_groups:
                if not isinstance(transformed_group, tuple):
                    transformed_group = [transformed_group]

                transformed_condition = dict(
                    zip(transformed_columns, transformed_group))
                sampled_rows = self._conditionally_sample_rows(
                    transformed_dataframe,
                    condition,
                    transformed_condition,
                    max_tries,
                    batch_size_per_try,
                    progress_bar=progress_bar,
                    output_file_path=output_file_path,
                )
                all_sampled_rows.append(sampled_rows)

        if not all_sampled_rows:
            # No condition groups (e.g. empty ``conditions``): ``pd.concat([])``
            # would raise, so return an empty frame instead.
            return pd.DataFrame()

        all_sampled_rows = pd.concat(all_sampled_rows)
        if len(all_sampled_rows) == 0:
            return all_sampled_rows

        all_sampled_rows = all_sampled_rows.set_index(COND_IDX)
        all_sampled_rows.index.name = original_index_name
        all_sampled_rows = all_sampled_rows.sort_index()
        all_sampled_rows = self._metadata.make_ids_unique(all_sampled_rows)

        return all_sampled_rows

    def sample_conditions(self,
                          conditions,
                          max_tries=100,
                          batch_size_per_try=None,
                          randomize_samples=True,
                          output_file_path=None):
        """Sample rows from this table with the given conditions.

        Args:
            conditions (list[sdv.sampling.Condition]):
                A list of sdv.sampling.Condition objects, which specify the column
                values in a condition, along with the number of rows for that
                condition.
            max_tries (int):
                Number of times to try sampling discarded rows. Defaults to 100.
            batch_size_per_try (int):
                The batch size to use per attempt at sampling. Defaults to 10 times
                the number of rows.
            randomize_samples (bool):
                Whether or not to use a fixed seed when sampling. Defaults
                to True.
            output_file_path (str or None):
                The file to periodically write sampled rows to. Defaults to
                a temporary file, if None.

        Returns:
            pandas.DataFrame:
                Sampled data.

        Raises:
            ConstraintsNotMetError:
                If the conditions are not valid for the given constraints.
            ValueError:
                If any of the following happens:
                    * any of the conditions' columns are not valid.
                    * no rows could be generated.
        """
        return self._sample_conditions(conditions, max_tries,
                                       batch_size_per_try, randomize_samples,
                                       output_file_path)

    def _sample_conditions(self, conditions, max_tries, batch_size_per_try,
                           randomize_samples, output_file_path):
        """Sample rows from this table with the given conditions."""
        output_file_path = self._validate_file_path(output_file_path)

        num_rows = functools.reduce(
            lambda num_rows, condition: condition.get_num_rows() + num_rows,
            conditions, 0)

        conditions = self._make_condition_dfs(conditions)
        for condition_dataframe in conditions:
            self._validate_conditions(condition_dataframe)

        self._randomize_samples(randomize_samples)

        sampled = pd.DataFrame()
        try:

            def _sample_function(progress_bar=None):
                sampled = pd.DataFrame()
                for condition_dataframe in conditions:
                    sampled_for_condition = self._sample_with_conditions(
                        condition_dataframe,
                        max_tries,
                        batch_size_per_try,
                        progress_bar,
                        output_file_path,
                    )
                    sampled = pd.concat([sampled, sampled_for_condition],
                                        ignore_index=True)

                return sampled

            if len(conditions) == 1 and max_tries == 1:
                sampled = _sample_function()
            else:
                sampled = progress_bar_wrapper(_sample_function, num_rows,
                                               'Sampling conditions')

            check_num_rows(
                len(sampled),
                num_rows,
                (hasattr(self, '_model') and not isinstance(
                    self._model, copulas.multivariate.GaussianMultivariate)),
                max_tries,
                batch_size_per_try,
            )

        except (Exception, KeyboardInterrupt) as error:
            handle_sampling_error(output_file_path == TMP_FILE_NAME,
                                  output_file_path, error)

        else:
            if output_file_path == TMP_FILE_NAME and os.path.exists(
                    output_file_path):
                os.remove(output_file_path)

        return sampled

    def sample_remaining_columns(self,
                                 known_columns,
                                 max_tries=100,
                                 batch_size_per_try=None,
                                 randomize_samples=True,
                                 output_file_path=None):
        """Sample the columns that are missing from ``known_columns``.

        Every row of ``known_columns`` acts as a condition: the returned
        DataFrame contains one sampled row per input row, generated
        conditionally on the values already known for that row.

        Args:
            known_columns (pandas.DataFrame):
                The columns whose values are already known.
            max_tries (int):
                Number of times to try sampling discarded rows. Defaults to 100.
            batch_size_per_try (int):
                The batch size to use per attempt at sampling. Defaults to 10 times
                the number of rows.
            randomize_samples (bool):
                Whether to randomize the samples between calls (the
                alternative being a fixed seed). Defaults to True.
            output_file_path (str or None):
                The file to periodically write sampled rows to. Defaults to
                a temporary file, if None.

        Returns:
            pandas.DataFrame:
                Sampled data.

        Raises:
            ConstraintsNotMetError:
                If the conditions are not valid for the given constraints.
            ValueError:
                If any of the following happens:
                    * any of the conditions' columns are not valid.
                    * no rows could be generated.
        """
        # All the real work lives in the private implementation; arguments
        # are forwarded positionally, in declaration order.
        forwarded = (known_columns, max_tries, batch_size_per_try,
                     randomize_samples, output_file_path)
        return self._sample_remaining_columns(*forwarded)

    def _sample_remaining_columns(self, known_columns, max_tries,
                                  batch_size_per_try, randomize_samples,
                                  output_file_path):
        """Sample the remaining columns of a given DataFrame.

        Args:
            known_columns (pandas.DataFrame):
                Values of the columns that are already known; each row is
                used as a condition. It is copied before use, so the
                caller's DataFrame is not modified by this method.
            max_tries (int):
                Number of sampling attempts allowed.
            batch_size_per_try (int or None):
                Batch size used per attempt.
            randomize_samples (bool):
                Forwarded to ``self._randomize_samples``.
            output_file_path (str or None):
                Where sampled rows are written; validated (and possibly
                replaced) by ``self._validate_file_path``.

        Returns:
            pandas.DataFrame:
                The sampled rows (empty if sampling failed and
                ``handle_sampling_error`` did not raise).
        """
        output_file_path = self._validate_file_path(output_file_path)

        self._randomize_samples(randomize_samples)

        known_columns = known_columns.copy()
        self._validate_conditions(known_columns)
        sampled = pd.DataFrame()
        try:

            def _sample_function(progress_bar=None):
                return self._sample_with_conditions(known_columns, max_tries,
                                                    batch_size_per_try,
                                                    progress_bar,
                                                    output_file_path)

            num_rows = len(known_columns)
            # A single row sampled in a single try is run without a
            # progress bar.
            if num_rows == 1 and max_tries == 1:
                sampled = _sample_function()
            else:
                sampled = progress_bar_wrapper(_sample_function, num_rows,
                                               'Sampling remaining columns')

            # Must be computed exactly like the conditional-sampling path:
            # True for every model except copulas' GaussianMultivariate.
            # Presumably this tells check_num_rows whether reject sampling
            # was used -- confirm against check_num_rows.
            # (Previously the ``not`` was missing, inverting the flag.)
            is_reject_sampling = (hasattr(self, '_model') and not isinstance(
                self._model, copulas.multivariate.GaussianMultivariate))
            check_num_rows(
                len(sampled),
                num_rows,
                is_reject_sampling,
                max_tries,
                batch_size_per_try,
            )

        except (Exception, KeyboardInterrupt) as error:
            handle_sampling_error(output_file_path == TMP_FILE_NAME,
                                  output_file_path, error)

        else:
            # On success, drop the temporary output file if one was used.
            if output_file_path == TMP_FILE_NAME and os.path.exists(
                    output_file_path):
                os.remove(output_file_path)

        return sampled

    def _get_parameters(self):
        """Return the model-specific parameters.

        This base implementation does not support parameters and always
        raises ``NonParametricError``. Subclasses that support parameters
        are expected to override it.
        """
        raise NonParametricError

    def get_parameters(self):
        """Get the parameters learned from the data.

        The result is a flat (single level) dict containing everything
        needed to reproduce this model. It always includes ``'num_rows'``.
        Model-specific entries come from ``_get_parameters`` and are only
        requested when ``self._metadata.get_dtypes(ids=False)`` is
        non-empty.

        Subclasses which are not parametric, such as DeepLearning
        based models, raise a NonParametricError indicating that
        this method is not supported for their implementation.

        Returns:
            parameters (dict):
                flat dict (single level) which contains all the
                necessary parameters to be able to reproduce
                this model.

        Raises:
            NonParametricError:
                If the model is not parametric or cannot be described
                using a simple dictionary.
        """
        has_modeled_fields = bool(self._metadata.get_dtypes(ids=False))
        parameters = self._get_parameters() if has_modeled_fields else {}
        parameters['num_rows'] = self._num_rows
        return parameters

    def _set_parameters(self, parameters):
        """Restore the model-specific parameters.

        This base implementation does not support parameters and always
        raises ``NonParametricError``. Subclasses that support parameters
        are expected to override it.
        """
        raise NonParametricError

    def set_parameters(self, parameters):
        """Regenerate a previously learned model from its parameters.

        Subclasses which are not parametric, such as DeepLearning
        based models, raise a NonParametricError indicating that
        this method is not supported for their implementation.

        Args:
            parameters (dict):
                Model parameters, such as those returned by
                ``get_parameters``. Must contain a ``'num_rows'`` entry.
                The dict passed in is not modified, so it can be reused.

        Raises:
            NonParametricError:
                If the model is not parametric or cannot be described
                using a simple dictionary.
            KeyError:
                If ``parameters`` has no ``'num_rows'`` entry.
        """
        # Work on a shallow copy: popping from the caller's dict would make
        # a second call with the same dict fail with KeyError.
        parameters = dict(parameters)
        num_rows = parameters.pop('num_rows')
        # Null/NaN row counts become 0; otherwise round to the nearest int
        # and clamp to be non-negative.
        self._num_rows = 0 if pd.isnull(num_rows) else max(
            0, int(round(num_rows)))

        # Remaining (model-specific) parameters are only forwarded when
        # get_dtypes(ids=False) is non-empty (presumably: the table has
        # non-id fields to model).
        if self._metadata.get_dtypes(ids=False):
            self._set_parameters(parameters)

    def save(self, path):
        """Serialize this model instance to ``path`` with pickle.

        Before pickling, the result of ``get_package_versions`` for this
        instance's ``_model`` attribute (``None`` if absent) is stored in
        ``self._package_versions``. ``load`` later reads that attribute
        back.

        Args:
            path (str):
                Path where the SDV instance will be serialized.
        """
        inner_model = getattr(self, '_model', None)
        self._package_versions = get_package_versions(inner_model)

        with open(path, 'wb') as stream:
            pickle.dump(self, stream)

    @classmethod
    def load(cls, path):
        """Load a TabularModel instance from a given path.

        The file is unpickled, then its stored ``_package_versions`` (or
        ``None`` if missing) are passed to
        ``throw_version_mismatch_warning``.

        Warning: unpickling can execute arbitrary code, so only load files
        from trusted sources.

        Args:
            path (str):
                Path from which to load the instance.

        Returns:
            TabularModel:
                The loaded tabular model.
        """
        with open(path, 'rb') as stream:
            model = pickle.load(stream)

        saved_versions = getattr(model, '_package_versions', None)
        throw_version_mismatch_warning(saved_versions)
        return model