def _configure_range_supplier_for_data(field_spec, data):
    """
    configures the supplier based on the range data supplied

    :param field_spec: the full field spec; only the optional config.precision is read from it here
    :param data: a single range as [start, end] or [start, end, step]; step is 1 when omitted
    :return: the result of datacraft.suppliers.range_supplier for the range
    :raises datacraft.SpecException: if precision is set but not numeric, data is not a list/tuple
        with at least two elements, the inclusive-adjusted end is not larger than start, or
        range_supplier raises a ValueError
    """
    config = field_spec.get('config', {})
    precision = config.get('precision', None)
    if precision and not str(precision).isnumeric():
        raise datacraft.SpecException(
            f'precision must be valid integer {json.dumps(field_spec)}')

    # callers handling nested ranges pass each sub-list through without checking its shape,
    # so validate here instead of failing with a bare IndexError/TypeError on the lookups below
    if not isinstance(data, (list, tuple)) or len(data) < 2:
        raise datacraft.SpecException(
            f'data element for ranges type must be list with at least two elements:{json.dumps(field_spec)}')

    start = data[0]
    # default for built in range function is exclusive end, we want to default to inclusive as this is the
    # more intuitive behavior
    end = data[1] + 1
    if not end > start:
        raise datacraft.SpecException(
            f'end element must be larger than start:{json.dumps(field_spec)}')

    step = 1 if len(data) == 2 else data[2]
    try:
        return datacraft.suppliers.range_supplier(start, end, step, precision=precision)
    except ValueError as err:
        # surface supplier validation problems as spec errors, keeping the original cause
        raise datacraft.SpecException(str(err)) from err
def _configure_select_list_subset_supplier(field_spec, loader):
    """
    configures supplier for select_list_subset type

    The values to select from come either from the spec's own 'data' element or from the
    ref named by its 'ref' element. For a ref, its 'data' element is used when it has one,
    otherwise the ref itself is used as the data.

    :param field_spec: spec for the field; must carry a config with 'mean' or 'count'
    :param loader: used to load the config and to resolve the 'ref' name
    :return: the result of datacraft.suppliers.select_list_subset(data, **config)
    :raises datacraft.SpecException: if the config lacks both mean and count, both 'data' and
        'ref' are given, the ref cannot be found, or no data can be identified
    """
    config = datacraft.utils.load_config(field_spec, loader)
    if config is None or ('mean' not in config and 'count' not in config):
        raise datacraft.SpecException(f'Config with mean or count defined must be provided: {json.dumps(field_spec)}')

    has_ref = 'ref' in field_spec
    has_data = 'data' in field_spec
    if has_ref and has_data:
        raise datacraft.SpecException(f'Only one of "data" or "ref" can be provided for:{json.dumps(field_spec)}')

    data = None
    if has_ref:
        ref_name = field_spec.get('ref')
        # keep the resolved ref in its own variable so error messages can still show the original spec
        ref_spec = loader.get_ref(ref_name)
        if ref_spec is None:
            raise datacraft.SpecException(f'No ref with name {ref_name} found: {json.dumps(field_spec)}')
        data = ref_spec.get('data') if 'data' in ref_spec else ref_spec
    elif has_data:
        data = field_spec.get('data')

    if data is None:
        raise datacraft.SpecException(
            'Unable to identify data for ' + _SELECT_LIST_SUBSET_KEY + ' for spec: ' + json.dumps(field_spec))
    return datacraft.suppliers.select_list_subset(data, **config)
def _configure_unicode_range_supplier(spec, loader):
    """
    Build the supplier for a unicode_range spec.

    :param spec: field spec; its 'data' element must be present and be a list
    :param loader: passed to datacraft.utils.load_config to resolve the config
    :return: the result of datacraft.suppliers.unicode_range(data, **config)
    :raises datacraft.SpecException: when 'data' is missing or is not a list
    """
    has_data = 'data' in spec
    if not has_data:
        raise datacraft.SpecException(
            'data is Required Element for unicode_range specs: ' + json.dumps(spec))

    ranges = spec['data']
    if isinstance(ranges, list):
        # config is only loaded once the data has passed validation
        options = datacraft.utils.load_config(spec, loader)
        return datacraft.suppliers.unicode_range(ranges, **options)

    raise datacraft.SpecException(
        f'data should be a list or list of lists with two elements for {_UNICODE_RANGE_KEY} specs: '
        + json.dumps(spec))
def _configure_precise_ip(field_spec, _):
    """
    Build the value supplier for the ip.precise type.

    :param field_spec: field spec; its 'config' must exist and contain a 'cidr' entry
    :param _: unused (loader slot of the common configure signature)
    :return: the result of datacraft.suppliers.ip_precise(cidr, sample)
    :raises datacraft.SpecException: when the config or its cidr entry is missing
    """
    ip_config = field_spec.get('config')
    if ip_config is None:
        raise datacraft.SpecException('No config for: ' + json.dumps(field_spec) + ', param cidr required')

    cidr_value = ip_config.get('cidr')
    # 'sample' is read with a default of 'no'
    use_sampling = datacraft.utils.is_affirmative('sample', ip_config, 'no')
    if cidr_value is not None:
        return datacraft.suppliers.ip_precise(cidr_value, use_sampling)

    raise datacraft.SpecException('Invalid config for: ' + json.dumps(field_spec) + ', param cidr required')
def configure_supplier(field_spec, loader):
    """
    Configures the supplier from the provided field spec, using the huggingface
    fill-mask pipeline by default.

    :param field_spec: specification for the hf-fill-mask field; must contain 'seed-ref'
    :param loader: datacraft.Loader object
    :return: a HuggingFaceFillMaskSupplier wrapping the supplier built from the seed ref
    :raises datacraft.SpecException: when 'seed-ref' is not in the field spec

    Config Params:
    :key mask-token-placeholder: place holder that should show up in the seed strings,
        default '__MASK__'
    :key pipeline: name of the transformers pipeline to use, default is 'fill-mask'
    :key model-dir: directory to load model from, default is loader.datadir
    :key token-only: if only the generated token should be output apart from the context,
        default is to output the full sequence
    """
    if 'seed-ref' not in field_spec:
        raise datacraft.SpecException('seed-ref is required field for hf-fill-mask type: ' + json.dumps(field_spec))

    # NOTE(review): the lookup result is not checked here; presumably a missing ref yields None - confirm
    seed_spec = loader.refs.get(field_spec.get('seed-ref'))

    options = field_spec.get('config', {})
    placeholder = options.get('mask-token-placeholder', '__MASK__')
    pipeline = options.get('pipeline', 'fill-mask')
    model_location = options.get('model-dir', loader.datadir)
    only_token = options.get('token-only', False)

    # supplier producing the input strings for the transformer pipeline
    seed_supplier = loader.get_from_spec(seed_spec)
    return HuggingFaceFillMaskSupplier(seed_supplier, placeholder, pipeline, only_token, model_location)
def _configure_ip(field_spec, loader):
    """
    Build the value supplier for the ip type.

    :param field_spec: field spec whose config is forwarded as keyword arguments
    :param loader: passed to datacraft.utils.load_config to resolve the config
    :return: the result of datacraft.suppliers.ip_supplier(**config)
    :raises datacraft.SpecException: when ip_supplier raises a ValueError
    """
    ip_options = datacraft.utils.load_config(field_spec, loader)
    try:
        supplier = datacraft.suppliers.ip_supplier(**ip_options)
    except ValueError as err:
        # re-raise as a spec problem, chaining the original error
        raise datacraft.SpecException(str(err)) from err
    return supplier
def _configure_distribution_supplier(field_spec, _):
    """
    Build the supplier for distribution types.

    :param field_spec: field spec; its 'data' element is parsed by datacraft.distributions.from_string
    :param _: unused (loader slot of the common configure signature)
    :return: the result of datacraft.suppliers.distribution_supplier for the parsed distribution
    :raises datacraft.SpecException: when the spec has no 'data' element
    """
    if 'data' in field_spec:
        parsed = datacraft.distributions.from_string(field_spec['data'])
        return datacraft.suppliers.distribution_supplier(parsed)

    raise datacraft.SpecException(
        'required data element not defined for ' + _DISTRIBUTION_KEY + ' type : ' + json.dumps(field_spec))
def _configure_csv(field_spec, loader):
    """
    Build the csv value supplier for this field.

    The data file name comes from the config's 'datafile' entry, falling back to the
    registered 'csv_file' default, and is looked up under loader.datadir.

    :param field_spec: field spec whose config is forwarded as keyword arguments
    :param loader: supplies the config and the data directory
    :return: the result of datacraft.suppliers.csv(csv_path, **config)
    :raises datacraft.SpecException: when the resolved file path does not exist
    """
    options = datacraft.utils.load_config(field_spec, loader)
    default_file = datacraft.registries.get_default('csv_file')
    datafile = options.get('datafile', default_file)
    csv_path = f'{loader.datadir}/{datafile}'

    if os.path.exists(csv_path):
        return datacraft.suppliers.csv(csv_path, **options)

    raise datacraft.SpecException(
        f'Unable to locate data file: {datafile} in data dir: {loader.datadir} for spec: '
        + json.dumps(field_spec))
def _configure_templated_type(field_spec, loader):
    """
    Build the supplier for templated specs.

    :param field_spec: field spec; 'data' is passed on as the second argument to the templated supplier
    :param loader: passed to build_suppliers_map
    :return: the result of datacraft.suppliers.templated(suppliers_map, data)
    :raises datacraft.SpecException: when the spec has no 'data' element
    """
    if 'data' in field_spec:
        name_to_supplier = build_suppliers_map(field_spec, loader)
        template = field_spec.get('data', None)
        return datacraft.suppliers.templated(name_to_supplier, template)

    raise datacraft.SpecException(
        f'data is required field for templated specs: {json.dumps(field_spec)}')
def _configure_calculate_supplier(field_spec: dict, loader: datacraft.Loader):
    """
    Build the supplier for the calculate type.

    :param field_spec: field spec; must define 'formula'
    :param loader: passed to build_suppliers_map
    :return: the result of datacraft.suppliers.calculate for the suppliers map and formula
    :raises datacraft.SpecException: when 'formula' is absent or None
    """
    expression = field_spec.get('formula')
    if expression is not None:
        name_to_supplier = build_suppliers_map(field_spec, loader)
        return datacraft.suppliers.calculate(suppliers_map=name_to_supplier, formula=expression)

    raise datacraft.SpecException(f'Must define formula for calculate type. {json.dumps(field_spec)}')
def _configure_combine_list_supplier(field_spec, loader):
    """
    Build the supplier for the combine-list type.

    Each entry of 'refs' is turned into its own combine supplier by copying the field spec
    with 'refs' replaced by that entry; the results are wrapped with from_list_of_suppliers.

    :param field_spec: field spec; 'refs' must be a non-empty list whose first element is a list
    :param loader: passed to _load_combine_from_refs
    :return: the result of datacraft.suppliers.from_list_of_suppliers(suppliers, True)
    :raises datacraft.SpecException: when 'refs' is missing, empty, or its first element is not a list
    """
    if 'refs' not in field_spec:
        raise datacraft.SpecException(
            f'Must define refs for combine-list type. {json.dumps(field_spec)}')

    ref_groups = field_spec['refs']
    is_list_of_lists = len(ref_groups) >= 1 and isinstance(ref_groups[0], list)
    if not is_list_of_lists:
        raise datacraft.SpecException(
            f'refs pointer must be list of lists: i.e [["ONE", "TWO"]]. {json.dumps(field_spec)}')

    # one combine supplier per group, each built from a copy of the spec with its own refs
    combined = [
        _load_combine_from_refs({**field_spec, 'refs': group}, loader)
        for group in ref_groups
    ]
    return datacraft.suppliers.from_list_of_suppliers(combined, True)
def build_suppliers_map(field_spec: dict, loader: datacraft.Loader) -> dict:
    """
    Uses refs or fields to build a map of alias -> supplier.

    :param field_spec: field spec; must define exactly one of 'refs' or 'fields'
    :param loader: each mapped field or ref name is resolved with loader.get
    :return: dict keyed by alias with the loaded supplier as value
    :raises datacraft.SpecException: when neither or both of 'refs' and 'fields' are defined,
        or the resulting mappings are empty
    """
    has_refs = 'refs' in field_spec
    has_fields = 'fields' in field_spec
    if not (has_refs or has_fields):
        raise datacraft.SpecException(
            f'Must define one of fields or refs. {json.dumps(field_spec)}')
    if has_refs and has_fields:
        raise datacraft.SpecException(
            f'Must define only one of fields or refs. {json.dumps(field_spec)}')

    name_to_alias = _get_mappings(field_spec, 'refs')
    name_to_alias.update(_get_mappings(field_spec, 'fields'))
    if not name_to_alias:
        raise datacraft.SpecException(
            f'fields or refs empty: {json.dumps(field_spec)}')

    return {alias: loader.get(name) for name, alias in name_to_alias.items()}
def _configure_range_supplier(field_spec, _):
    """
    Build the range value supplier.

    When the first element of 'data' is itself a list, every element is configured as its
    own range and the results are wrapped with from_list_of_suppliers; otherwise 'data' is
    configured as a single range.

    :param field_spec: field spec; 'data' must be a list with at least two elements
    :param _: unused (loader slot of the common configure signature)
    :return: a single range supplier, or the from_list_of_suppliers wrapper for nested data
    :raises datacraft.SpecException: when 'data' is missing, not a list, or has fewer than two elements
    """
    if 'data' not in field_spec:
        raise datacraft.SpecException(
            f'No data element defined for: {json.dumps(field_spec)}')

    data = field_spec.get('data')
    well_formed = isinstance(data, list) and len(data) >= 2
    if not well_formed:
        raise datacraft.SpecException(
            f'data element for ranges type must be list with at least two elements:{json.dumps(field_spec)}')

    if not isinstance(data[0], list):
        return _configure_range_supplier_for_data(field_spec, data)

    # nested case: one supplier per sub-range
    per_range = []
    for sub_range in data:
        per_range.append(_configure_range_supplier_for_data(field_spec, sub_range))
    return datacraft.suppliers.from_list_of_suppliers(per_range, True)
def _configure_combine_supplier(field_spec, loader):
    """
    Build the supplier for the combine type.

    'refs' takes precedence: when present the supplier is loaded from refs, otherwise from fields.

    :param field_spec: field spec; must define 'refs' or 'fields'
    :param loader: passed to the refs/fields loading helper
    :return: the supplier produced by _load_combine_from_refs or _load_combine_from_fields
    :raises datacraft.SpecException: when neither 'refs' nor 'fields' is defined
    """
    uses_refs = 'refs' in field_spec
    if not uses_refs and 'fields' not in field_spec:
        raise datacraft.SpecException(
            f'Must define one of fields or refs. {json.dumps(field_spec)}')

    if uses_refs:
        return _load_combine_from_refs(field_spec, loader)
    return _load_combine_from_fields(field_spec, loader)
def _configure_uuid_supplier(field_spec, loader):
    """
    Build the supplier for uuid types.

    :param field_spec: field spec; config 'variant' falls back to the registered 'uuid_variant' default
    :param loader: passed to datacraft.utils.load_config to resolve the config
    :return: the result of datacraft.suppliers.uuid(variant)
    :raises datacraft.SpecException: when the variant, converted to int, is not one of 1, 3, 4, 5
    """
    options = datacraft.utils.load_config(field_spec, loader)
    default_variant = datacraft.registries.get_default('uuid_variant')
    variant = int(options.get('variant', default_variant))

    if variant in (1, 3, 4, 5):
        return datacraft.suppliers.uuid(variant)
    raise datacraft.SpecException('Invalid variant for: ' + json.dumps(field_spec))
def _configure_ref_supplier(field_spec: dict, loader: datacraft.Loader):
    """
    Build the supplier for the ref type.

    The key to load is taken from 'ref' when that element is present, otherwise from 'data'.

    :param field_spec: field spec naming the key through 'ref' or 'data'
    :param loader: the key is resolved with loader.get
    :return: the result of loader.get(key)
    :raises datacraft.SpecException: when no key can be determined
    """
    # 'ref' wins over 'data' whenever it is present
    if 'ref' in field_spec:
        key = field_spec.get('ref')
    else:
        key = field_spec.get('data')

    if key is not None:
        return loader.get(key)
    raise datacraft.SpecException('No key found for spec: ' + json.dumps(field_spec))
def next(self, iteration):
    """
    Produce the next value by running the wrapped supplier's output through self.nlp.

    The wrapped value is converted to a string, the configured placeholder is replaced by
    self.mask_token, and one of the candidates returned by self.nlp is picked at random.

    :param iteration: passed straight to the wrapped supplier's next()
    :return: the candidate's 'token_str' when self.token_only is set, else its 'sequence'
    :raises datacraft.SpecException: when the placeholder is not in the wrapped value
    """
    seed_text = str(self.wrapped.next(iteration))
    placeholder = self.mask_token_placeholder
    if placeholder not in seed_text:
        raise datacraft.SpecException(
            f'Mask token placeholder: {self.mask_token_placeholder} not found in generated data!')

    masked_text = seed_text.replace(placeholder, self.mask_token)
    candidates = self.nlp(masked_text)
    # just take a random candidate
    (chosen,) = random.sample(candidates, 1)
    result_key = 'token_str' if self.token_only else 'sequence'
    return chosen[result_key]
def _configure_rand_range_supplier(field_spec, loader):
    """
    Build the random range value supplier.

    'data' is read as [end], [start, end] or [start, end, precision]; with a single element
    the start is 0. A 'precision' entry in the config overrides the third data element.

    :param field_spec: field spec; 'data' must be a non-empty list
    :param loader: passed to datacraft.utils.load_config to resolve the config
    :return: the result of datacraft.suppliers.random_range(start, end, precision)
    :raises datacraft.SpecException: when 'data' is missing, not a list, or empty
    """
    if 'data' not in field_spec:
        raise datacraft.SpecException(
            f'No data element defined for: {json.dumps(field_spec)}')

    bounds = field_spec.get('data')
    options = datacraft.utils.load_config(field_spec, loader)
    if not isinstance(bounds, list) or not bounds:
        raise datacraft.SpecException(
            f'rand_range specs require data as array with at least one element: {json.dumps(field_spec)}')

    if len(bounds) == 1:
        start, end = 0, bounds[0]
    else:
        start, end = bounds[0], bounds[1]

    from_data = bounds[2] if len(bounds) > 2 else None
    # config overrides third data element if specified
    precision = options.get('precision', from_data)
    return datacraft.suppliers.random_range(start, end, precision)
def _configure_char_class_supplier(spec, loader):
    """
    Build the supplier for char_class types.

    A string 'data' that is a key of _CLASS_MAPPING is replaced by the mapped value. A list
    'data' has each element that is a key of _CLASS_MAPPING replaced by its mapped value,
    and the elements are then joined into one string. When the config has no 'join_with',
    the registered 'char_class_join_with' default is set on it.

    :param spec: field spec; must contain 'data'
    :param loader: passed to datacraft.utils.load_config to resolve the config
    :return: the result of datacraft.suppliers.character_class(data, **config)
    :raises datacraft.SpecException: when the spec has no 'data' element
    """
    if 'data' not in spec:
        raise datacraft.SpecException(
            f'Data is required field for char_class type: {json.dumps(spec)}')

    options = datacraft.utils.load_config(spec, loader)
    chars = spec['data']

    if isinstance(chars, str) and chars in _CLASS_MAPPING:
        chars = _CLASS_MAPPING[chars]
    # deliberately a separate check: runs on whatever chars holds after the lookup above
    if isinstance(chars, list):
        chars = ''.join(_CLASS_MAPPING.get(item, item) for item in chars)

    if 'join_with' not in options:
        options['join_with'] = datacraft.registries.get_default('char_class_join_with')
    return datacraft.suppliers.character_class(chars, **options)
def _configure_weighted_csv(field_spec, loader):
    """
    Build the weighted_csv value supplier for this field.

    Values and weights are read from a csv file under loader.datadir. The columns are read
    by name only when the config's 'headers' flag is affirmative and 'column' is not an int;
    otherwise both columns are read by index.

    Config keys read here: 'column' (default 1), 'weight_column' (default 2), 'datafile'
    (default is the registered 'csv_file'), 'headers'. The whole config is also forwarded
    to datacraft.suppliers.count_supplier.

    :param field_spec: field spec providing the config
    :param loader: supplies the config and the data directory
    :return: the result of weighted_values_explicit(choices, weights, count_supplier)
    :raises datacraft.SpecException: when 'column' is an int smaller than 1
    """
    options = datacraft.utils.load_config(field_spec, loader)
    column = options.get('column', 1)
    weight_column = options.get('weight_column', 2)
    count_supplier = datacraft.suppliers.count_supplier(**options)

    datafile = options.get('datafile', datacraft.registries.get_default('csv_file'))
    csv_path = f'{loader.datadir}/{datafile}'
    has_headers = datacraft.utils.is_affirmative('headers', options)

    by_index = isinstance(column, int)
    if by_index and column < 1:
        raise datacraft.SpecException(f'Invalid index {column}, one based indexing used for column numbers')

    if by_index or not has_headers:
        # NOTE(review): skip_first follows the column type, not the headers flag - confirm this is intended
        choices = _read_indexed_column(csv_path, int(column), skip_first=by_index)
        weights = _read_indexed_column_weights(csv_path, int(weight_column), skip_first=by_index)
    else:
        choices = _read_named_column(csv_path, column)
        weights = _read_named_column_weights(csv_path, weight_column)

    return weighted_values_explicit(choices, weights, count_supplier)