def gen_set_params(algo, pnames, schema_params, required_params, skip_params=None):
    """Generate the R statements that validate arguments and assemble the ``parms`` list.

    Items are yielded one by one (some are multi-line blocks from ``reformat_block``);
    callers join them with newlines.

    :param algo: algorithm name used to look up customizations.
    :param pnames: parameter names present in the generated R signature; a frame
        parameter is only validated when it appears here.
    :param schema_params: iterable of schema parameter names; each one gets an
        ``if (!missing(...))`` assignment into ``parms`` unless skipped.
    :param required_params: collection of required parameter names; membership
        decides the ``required=`` flag passed to ``.validate.H2OFrame``.
    :param skip_params: optional names removed from the R signature; they are
        declared as ``NULL`` up front and never copied into ``parms``.
    """
    frames_comment = ("  # Validate required training_frame first and other frame args: "
                      "should be a valid key or an H2OFrame object")

    if skip_params:
        yield "  # formally define variables that were excluded from function parameters"
        for skipped in skip_params:
            yield "  %s <- NULL" % skipped

    custom_frame_validation = get_customizations_or_defaults_for(algo, 'extensions.validate_frames')
    if custom_frame_validation:
        yield frames_comment
        yield reformat_block(custom_frame_validation, indent=2)
    else:
        frames = get_customizations_or_defaults_for(algo, 'extensions.frame_params', [])
        if frames:
            yield frames_comment
        for frame in (f for f in frames if f in pnames):
            flag = "TRUE" if frame in required_params else "FALSE"
            yield "  {frame} <- .validate.H2OFrame({frame}, required={required})".format(frame=frame, required=flag)

    # Algo-specific validation snippets, looked up and emitted in this fixed order.
    for custom_key, header in (('extensions.validate_required_params', "  # Validate other required args"),
                               ('extensions.validate_params', "  # Validate other args")):
        snippet = get_customizations_or_defaults_for(algo, custom_key)
        if snippet:
            yield ""
            yield header
            yield reformat_block(snippet, indent=2)

    yield ""
    yield "  # Build parameter list to send to model builder"
    yield "  parms <- list()"
    required_setters = get_customizations_or_defaults_for(algo, 'extensions.set_required_params')
    if required_setters:
        yield reformat_block(required_setters, indent=2)

    no_default_setter = get_customizations_or_defaults_for(algo, 'extensions.skip_default_set_params_for', [])
    yield ""
    # 'loss' keeps its special handling here because several algos share it:
    # the deprecated value 'MeanSquare' is rewritten to 'Quadratic' with a warning.
    loss_lines = (
        "  if(!missing(loss)) {",
        "    if(loss == \"MeanSquare\") {",
        "      warning(\"Loss name 'MeanSquare' is deprecated; please use 'Quadratic' instead.\")",
        "      parms$loss <- \"Quadratic\"",
        "    } else ",
        "      parms$loss <- loss",
        "  }",
    )
    for pname in schema_params:
        excluded = pname in no_default_setter or (skip_params and pname in skip_params)
        if excluded:
            continue
        if pname != "loss":
            yield "  if (!missing(%s))" % pname
            yield "    parms$%s <- %s" % (pname, pname)
            continue
        for line in loss_lines:
            yield line

    extra_setters = get_customizations_or_defaults_for(algo, 'extensions.set_params')
    if extra_setters:
        yield ""
        yield reformat_block(extra_setters, indent=2)
def gen_module(schema, algo, module):
    """Generate the R source of the ``h2o.<module>`` model-builder function for ``algo``.

    Yields the file header, roxygen documentation, the function signature and the
    function body. The caller joins the yielded items with newlines; several items
    are already multi-line blocks (``reformat_block`` output, or the newline-joined
    output of ``gen_set_params``).

    :param schema: model-builder schema; its ``'parameters'`` entries (dicts with at
        least a ``'name'`` key) drive the documentation, signature and ``parms`` list.
    :param algo: algorithm name used for customization lookups and passed to
        ``.h2o.modelJob``.
    :param module: suffix of the generated R function name (``h2o.<module>``).
    """
    rest_api_version = get_customizations_for(algo, 'rest_api_version', 3)
    doc_preamble = get_customizations_for(algo, 'doc.preamble')
    doc_returns = get_customizations_for(algo, 'doc.returns')
    doc_seealso = get_customizations_for(algo, 'doc.seealso')
    doc_references = get_customizations_for(algo, 'doc.references')
    doc_examples = get_customizations_for(algo, 'doc.examples')

    required_params = get_customizations_or_defaults_for(algo, 'extensions.required_params', [])
    extra_params = get_customizations_or_defaults_for(algo, 'extensions.extra_params', [])
    model_name = algo_to_modelname(algo)

    update_param_defaults = get_customizations_for('defaults', 'update_param')
    update_param = get_customizations_for(algo, 'update_param')

    yield "# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_R.py"
    yield "# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details) \n#'"
    yield "# -------------------------- %s -------------------------- #" % model_name

    # start documentation
    if doc_preamble:
        yield "#'"
        yield reformat_block(doc_preamble, prefix="#' ")
    yield "#'"

    # Required/extra param entries are either a bare name or a (name, default) tuple.
    required_params = odict([(p[0] if isinstance(p, tuple) else p, p[1] if isinstance(p, tuple) else None)
                             for p in required_params])
    schema_params = odict([(p['name'], p) for p in schema['parameters']])
    extra_params = odict([(p[0] if isinstance(p, tuple) else p, p[1] if isinstance(p, tuple) else None)
                          for p in extra_params])
    all_params = list(required_params.keys()) + list(schema_params.keys()) + list(extra_params.keys())

    def get_schema_params(pname):
        """Return the (possibly updated) schema param(s) for ``pname`` as a list.

        The algo-specific updater is tried before the default one; the first that
        returns a non-None value wins. A list/tuple result supports deprecated aliases.
        """
        param = deepcopy(schema_params[pname])
        updates = None
        for update_fn in [update_param, update_param_defaults]:
            if callable(update_fn):
                updates = update_fn(pname, param)
            if updates is not None:
                param = updates
                break
        return param if isinstance(param, (list, tuple)) else [param]

    tag = "@param"
    pdocs = odict()
    for pname in all_params:
        if pname in pdocs:  # avoid duplicates (esp. if already included in required_params)
            continue
        if pname in schema_params:
            for param in get_schema_params(pname):  # retrieve potential aliases
                alias_name = param.get('name')
                if alias_name:
                    pdocs[alias_name] = get_customizations_or_defaults_for(
                        algo, 'doc.params.' + alias_name, get_help(param, indent=len(tag) + 4))
        else:
            pdocs[pname] = get_customizations_or_defaults_for(algo, 'doc.params.' + pname)

    for pname, pdoc in pdocs.items():
        if pdoc:
            yield reformat_block("%s %s %s" % (tag, pname, pdoc.lstrip('\n')),
                                 indent=len(tag) + 1, indent_first=False, prefix="#' ")

    for doc_tag, doc_text in (("@return", doc_returns), ("@seealso", doc_seealso), ("@references", doc_references)):
        if doc_text:
            yield reformat_block("%s %s" % (doc_tag, doc_text.lstrip('\n')),
                                 indent=len(doc_tag) + 1, indent_first=False, prefix="#' ")

    if doc_examples:
        yield "#' @examples"
        yield "#' \\dontrun{"  # escaped: a bare "\d" is an invalid Python escape sequence
        yield reformat_block(doc_examples, prefix="#' ")
        yield "#' }"
    yield "#' @export"

    # start function signature
    sig_pnames = []
    sig_params = []
    for k, v in required_params.items():
        sig_pnames.append(k)
        sig_params.append(k if v is None else '%s = %s' % (k, v))

    for pname in schema_params:
        for param in get_schema_params(pname):
            alias_name = param.get('name')  # a param can be an alias of pname
            if alias_name in required_params or not alias_name:
                # skip schema params already added by required_params, and those explicitly removed
                continue
            sig_pnames.append(alias_name)
            sig_params.append("%s = %s" % (alias_name, get_sig_default_value(param)))

    for k, v in extra_params.items():
        sig_pnames.append(k)
        sig_params.append("%s = %s" % (k, v))

    param_indent = len("h2o.%s <- function(" % module)
    yield reformat_block("h2o.%s <- function(%s)" % (module, ',\n'.join(sig_params)),
                         indent=param_indent, indent_first=False)

    # start function body: argument validation and `parms` construction are shared
    # with the other R generators through gen_set_params.
    yield "{"
    yield '\n'.join(gen_set_params(algo, sig_pnames, schema_params, required_params))

    yield ""
    yield "  # Error check and build model"
    verbose = 'verbose' if 'verbose' in extra_params else 'FALSE'
    yield "  model <- .h2o.modelJob('%s', parms, h2oRestApiVersion=%d, verbose=%s)" % (algo, rest_api_version, verbose)
    with_model = get_customizations_for(algo, 'extensions.with_model')
    if with_model:
        yield ""
        yield reformat_block(with_model, indent=2)
    yield "  return(model)"
    yield "}"

    # start additional functions
    module_extensions = get_customizations_for(algo, 'extensions.module')
    if module_extensions:
        yield ""
        yield module_extensions
def gen_module(schema, algo, module):
    """Generate the R source of ``h2o.<module>`` and, unless ``algo`` is 'generic',
    the companion ``.h2o.train_segments_<module>`` segment-model builder.

    The caller joins the yielded items with newlines; several items are already
    multi-line blocks (``reformat_block`` output, or the newline-joined output of
    ``gen_set_params``).

    :param schema: model-builder schema; its ``'parameters'`` entries (dicts with at
        least a ``'name'`` key) drive the documentation, signatures and ``parms`` list.
    :param algo: algorithm name used for customization lookups and passed to
        ``.h2o.modelJob`` / ``.h2o.segmentModelsJob``.
    :param module: suffix of the generated R function names.
    """
    rest_api_version = get_customizations_for(algo, 'rest_api_version', 3)
    doc_preamble = get_customizations_for(algo, 'doc.preamble')
    doc_returns = get_customizations_for(algo, 'doc.returns')
    doc_seealso = get_customizations_for(algo, 'doc.seealso')
    doc_references = get_customizations_for(algo, 'doc.references')
    doc_examples = get_customizations_for(algo, 'doc.examples')

    required_params = get_customizations_or_defaults_for(algo, 'extensions.required_params', [])
    extra_params = get_customizations_or_defaults_for(algo, 'extensions.extra_params', [])
    model_name = algo_to_modelname(algo)

    update_param_defaults = get_customizations_for('defaults', 'update_param')
    update_param = get_customizations_for(algo, 'update_param')

    yield "# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_R.py"
    yield "# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details) \n#'"
    yield "# -------------------------- %s -------------------------- #" % model_name

    # start documentation
    if doc_preamble:
        yield "#'"
        yield reformat_block(doc_preamble, prefix="#' ")
    yield "#'"

    # Required/extra param entries are either a bare name or a (name, default) tuple.
    required_params = odict([(p[0] if isinstance(p, tuple) else p, p[1] if isinstance(p, tuple) else None)
                             for p in required_params])
    schema_params = odict([(p['name'], p) for p in schema['parameters']])
    extra_params = odict([(p[0] if isinstance(p, tuple) else p, p[1] if isinstance(p, tuple) else None)
                          for p in extra_params])
    all_params = list(required_params.keys()) + list(schema_params.keys()) + list(extra_params.keys())

    def get_schema_params(pname):
        """Return the (possibly updated) schema param(s) for ``pname`` as a list.

        The algo-specific updater is tried before the default one; the first that
        returns a non-None value wins. A list/tuple result supports deprecated aliases.
        """
        param = deepcopy(schema_params[pname])
        updates = None
        for update_fn in [update_param, update_param_defaults]:
            if callable(update_fn):
                updates = update_fn(pname, param)
            if updates is not None:
                param = updates
                break
        return param if isinstance(param, (list, tuple)) else [param]

    tag = "@param"
    pdocs = odict()
    for pname in all_params:
        if pname in pdocs:  # avoid duplicates (esp. if already included in required_params)
            continue
        if pname in schema_params:
            for param in get_schema_params(pname):  # retrieve potential aliases
                alias_name = param.get('name')
                if alias_name:
                    pdocs[alias_name] = get_customizations_or_defaults_for(
                        algo, 'doc.params.' + alias_name, get_help(param, indent=len(tag) + 4))
        else:
            pdocs[pname] = get_customizations_or_defaults_for(algo, 'doc.params.' + pname)

    for pname, pdoc in pdocs.items():
        if pdoc:
            yield reformat_block("%s %s %s" % (tag, pname, pdoc.lstrip('\n')),
                                 indent=len(tag) + 1, indent_first=False, prefix="#' ")

    for doc_tag, doc_text in (("@return", doc_returns), ("@seealso", doc_seealso), ("@references", doc_references)):
        if doc_text:
            yield reformat_block("%s %s" % (doc_tag, doc_text.lstrip('\n')),
                                 indent=len(doc_tag) + 1, indent_first=False, prefix="#' ")

    if doc_examples:
        yield "#' @examples"
        yield "#' \\dontrun{"  # escaped: a bare "\d" is an invalid Python escape sequence
        yield reformat_block(doc_examples, prefix="#' ")
        yield "#' }"
    yield "#' @export"

    # start function signature
    sig_pnames = []
    sig_params = []
    for k, v in required_params.items():
        sig_pnames.append(k)
        sig_params.append(k if v is None else '%s = %s' % (k, v))

    for pname in schema_params:
        for param in get_schema_params(pname):
            alias_name = param.get('name')  # a param can be an alias of pname
            if alias_name in required_params or not alias_name:
                # skip schema params already added by required_params, and those explicitly removed
                continue
            sig_pnames.append(alias_name)
            sig_params.append("%s = %s" % (alias_name, get_sig_default_value(param)))

    for k, v in extra_params.items():
        sig_pnames.append(k)
        sig_params.append("%s = %s" % (k, v))

    param_indent = len("h2o.%s <- function(" % module)
    yield reformat_block("h2o.%s <- function(%s)" % (module, ',\n'.join(sig_params)),
                         indent=param_indent, indent_first=False)

    # start function body
    yield "{"
    yield '\n'.join(gen_set_params(algo, sig_pnames, schema_params, required_params))
    yield ""
    yield "  # Error check and build model"
    verbose = 'verbose' if 'verbose' in extra_params else 'FALSE'
    yield "  model <- .h2o.modelJob('%s', parms, h2oRestApiVersion=%d, verbose=%s)" % (algo, rest_api_version, verbose)
    with_model = get_customizations_for(algo, 'extensions.with_model')
    if with_model:
        yield ""
        yield reformat_block(with_model, indent=2)
    yield "  return(model)"
    yield "}"

    if algo != "generic":
        #
        # Segment model building
        #
        bulk_pnames_skip = ["model_id", "verbose", "destination_key"]  # destination_key is only for SVD
        # Comprehensions instead of zip(*filter(...)): the unpacking form raised
        # IndexError when every signature parameter was skipped.
        bulk_params = [(name, sig) for name, sig in zip(sig_pnames, sig_params) if name not in bulk_pnames_skip]
        bulk_pnames = [name for name, _ in bulk_params]
        sig_bulk_params = [sig for _, sig in bulk_params]
        sig_bulk_params.append("segment_columns = NULL")
        sig_bulk_params.append("segment_models_id = NULL")
        sig_bulk_params.append("parallelism = 1")

        bulk_param_indent = len(".h2o.train_segments_%s <- function(" % module)
        yield reformat_block(".h2o.train_segments_%s <- function(%s)" % (module, ',\n'.join(sig_bulk_params)),
                             indent=bulk_param_indent, indent_first=False)

        # start train_segments-function body; skipped params are declared NULL by gen_set_params
        yield "{"
        yield '\n'.join(gen_set_params(algo, bulk_pnames, schema_params, required_params, bulk_pnames_skip))
        yield ""
        yield "  # Build segment-models specific parameters"
        yield "  segment_parms <- list()"
        yield "  if (!missing(segment_columns))"
        yield "    segment_parms$segment_columns <- segment_columns"
        yield "  if (!missing(segment_models_id))"
        yield "    segment_parms$segment_models_id <- segment_models_id"
        yield "  segment_parms$parallelism <- parallelism"
        yield ""
        yield "  # Error check and build segment models"
        yield "  segment_models <- .h2o.segmentModelsJob('%s', segment_parms, parms, h2oRestApiVersion=%d)" % (algo, rest_api_version)
        yield "  return(segment_models)"
        yield "}"

    #
    # Additional functions
    #
    module_extensions = get_customizations_for(algo, 'extensions.module')
    if module_extensions:
        yield ""
        yield module_extensions
def gen_module(schema, algo):
    """
    Ideally we should be able to avoid logic specific to algos in this file.
    Instead, customizations are externalized in ./python/gen_{algo}.py files.
    Logic that is specific to python types (e.g. H2OFrame, enums as list...) should however stay here
    as the type translation is done in this file.

    Yields the source lines of the generated estimator module for ``algo``: an
    explicit-signature ``__init__`` plus a property getter/setter per schema param,
    and ``deprecated_property`` aliases for the algo's deprecated params.

    NOTE: custom getter/setter templates are rendered with ``.format(**locals())``,
    so the local variable names of this function are part of that template
    contract and must not be renamed or removed.
    """
    classname = algo_to_classname(algo)
    extra_imports = get_customizations_for(algo, 'extensions.__imports__')
    class_doc = get_customizations_for(algo, 'doc.__class__')
    class_examples = get_customizations_for(algo, 'examples.__class__')
    class_extras = get_customizations_for(algo, 'extensions.__class__')
    module_extras = get_customizations_for(algo, 'extensions.__module__')

    update_param_defaults = get_customizations_for('defaults', 'update_param')
    update_param = get_customizations_for(algo, 'update_param')
    # mapping old_name -> new_name, or old_name -> (new_name, ...)
    deprecated_params = get_customizations_for(algo, 'deprecated_params', {})

    def extend_schema_params(param):
        # The algo-specific updater is tried before the default one; the first
        # non-None result replaces the (deep-copied) schema param.
        pname = param.get('name')
        param = deepcopy(param)
        updates = None
        for update_fn in [update_param, update_param_defaults]:
            if callable(update_fn):
                updates = update_fn(pname, param)
            if updates is not None:
                param = updates
                break
        return param

    extended_params = [extend_schema_params(p) for p in schema['parameters']]

    param_names = []
    for param in extended_params:
        pname = param.get('name')
        ptype = param.get('type')
        pvalues = param.get('values')
        pdefault = param.get('default_value')

        assert (ptype[:4] == 'enum') == bool(pvalues), "Values are expected for enum types only"
        if pvalues:
            enum_values = [normalize_enum_constant(p) for p in pvalues]
            if pdefault:
                pdefault = normalize_enum_constant(pdefault)
        else:
            enum_values = None

        if pname in reserved_words:
            pname += "_"
        param_names.append(pname)
        param['pname'] = pname
        param['default_value'] = pdefault
        param['ptype'] = translate_type_for_check(ptype, enum_values)
        param['dtype'] = translate_type_for_doc(ptype, enum_values)

    if deprecated_params:
        # deprecated params are exposed only through deprecated_property aliases below
        extended_params = [p for p in extended_params if p['pname'] not in deprecated_params]

    yield "#!/usr/bin/env python"
    yield "# -*- encoding: utf-8 -*-"
    yield "#"
    yield "# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_python.py"
    yield "# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details)"
    yield "#"
    yield "from __future__ import absolute_import, division, print_function, unicode_literals"
    yield ""
    if deprecated_params:
        yield "from h2o.utils.metaclass import deprecated_params, deprecated_property"
    if extra_imports:
        yield reformat_block(extra_imports)
    yield "from h2o.estimators.estimator_base import H2OEstimator"
    yield "from h2o.exceptions import H2OValueError"
    yield "from h2o.frame import H2OFrame"
    yield "from h2o.utils.typechecks import assert_is_type, Enum, numeric"
    yield ""
    yield ""
    yield "class %s(H2OEstimator):" % classname
    yield '    """'
    yield "    " + schema["algo_full_name"]
    yield ""
    if class_doc:
        yield reformat_block(class_doc, 4)
    if class_examples:
        yield ""
        yield "    :examples:"
        yield ""
        yield reformat_block(class_examples, 4)
    yield '    """'
    yield ""
    yield '    algo = "%s"' % algo
    yield "    supervised_learning = %s" % get_customizations_for(algo, 'supervised_learning', True)
    options = get_customizations_for(algo, 'options')
    if options:
        # 16 == len("    _options_ = "): continuation lines align under the literal
        yield "    _options_ = %s" % reformat_block(pformat(options), prefix=' ' * 16, prefix_first=False)
    yield ""
    if deprecated_params:
        yield reformat_block("@deprecated_params(%s)" % deprecated_params, indent=4)
    init_args = "\n".join(
        "%s=%s,  # type: %s" % (name, default, "Optional[%s]" % hint if default is None else hint)
        for name, default, hint in [(p.get('pname'), stringify(p.get('default_value'), infinity=None), p.get('dtype'))
                                    for p in extended_params])
    init_sig = "def __init__(self,\n%s\n):" % init_args
    # 13 == len("def __init__("): arguments align after the opening parenthesis
    yield reformat_block(init_sig, indent=4, prefix=' ' * 13, prefix_first=False)
    yield '        """'
    for p in extended_params:
        pname, pdefault, dtype, pdoc = p.get('pname'), stringify(p.get('default_value')), p.get('dtype'), p.get('help')
        pdesc = "%s: %s\nDefaults to ``%s``." % (pname, pdoc, pdefault)
        pident = ' ' * 15  # == len("        :param ")
        yield "        :param %s" % bi.wrap(pdesc, indent=pident, indent_first=False)
        yield "        :type %s: %s%s" % (pname, bi.wrap(dtype, indent=pident, indent_first=False),
                                           ", optional" if pdefault is None else "")
    yield '        """'
    yield "        super(%s, self).__init__()" % classname
    yield "        self._parms = {}"
    for p in extended_params:
        pname = p.get('pname')
        if pname == 'model_id':
            yield "        self._id = self._parms['model_id'] = model_id"
        else:
            yield "        self.%s = %s" % (pname, pname)
    rest_api_version = get_customizations_for(algo, 'rest_api_version')
    if rest_api_version:
        yield '        self._parms["_rest_version"] = %s' % rest_api_version
    yield ""

    for param in extended_params:
        pname = param.get('pname')
        if pname == "model_id":
            continue  # The getter is already defined in ModelBase

        sname = pname[:-1] if pname[-1] == '_' else pname  # server-side name without reserved-word suffix
        ptype = param.get('ptype')
        dtype = param.get('dtype')
        pdefault = param.get('default_value')
        if dtype.startswith("Enum"):
            vals = dtype[5:-1].split(", ")
            property_doc = "One of: " + ", ".join("``%s``" % v for v in vals)
        else:
            property_doc = "Type: ``%s``" % dtype
        property_doc += ("." if pdefault is None else ", defaults to ``%s``." % stringify(pdefault))

        yield "    @property"
        yield "    def %s(self):" % pname
        yield '        """'
        yield bi.wrap(param.get('help'), indent=8 * ' ')  # we need to wrap only for text coming from server
        yield ""
        yield bi.wrap(property_doc, indent=8 * ' ')
        custom_property_doc = get_customizations_for(algo, "doc.{}".format(pname))
        if custom_property_doc:
            yield ""
            yield reformat_block(custom_property_doc, 8)
        property_examples = get_customizations_for(algo, "examples.{}".format(pname))
        if property_examples:
            yield ""
            yield "        :examples:"
            yield ""
            yield reformat_block(property_examples, 8)
        yield '        """'
        property_getter = get_customizations_or_defaults_for(algo, "overrides.{}.getter".format(pname))  # check gen_stackedensemble.py for an example
        if property_getter:
            yield reformat_block(property_getter.format(**locals()), 8)
        else:
            yield "        return self._parms.get(\"%s\")" % sname
        yield ""
        yield "    @%s.setter" % pname
        yield "    def %s(self, %s):" % (pname, pname)
        property_setter = get_customizations_or_defaults_for(algo, "overrides.{}.setter".format(pname))  # check gen_stackedensemble.py for an example
        if property_setter:
            yield reformat_block(property_setter.format(**locals()), 8)
        elif "H2OFrame" in ptype:
            yield "        self._parms[\"%s\"] = H2OFrame._validate(%s, '%s')" % (sname, pname, pname)
        else:
            yield "        assert_is_type(%s, None, %s)" % (pname, ptype)
            yield "        self._parms[\"%s\"] = %s" % (sname, pname)
        yield ""

    for old, new in deprecated_params.items():
        # Values may be a (new_name, ...) tuple; only the replacement property name
        # belongs in the generated call (formatting the raw tuple emitted invalid code).
        new_name = new[0] if isinstance(new, tuple) else new
        yield "    %s = deprecated_property('%s', %s)" % (old, old, new_name)

    yield ""
    if class_extras:
        yield reformat_block(code_as_str(class_extras), 4)
    if module_extras:
        yield ""
        yield reformat_block(code_as_str(module_extras))
def gen_module(schema, algo):
    """
    Generate the source of the Python estimator module for ``algo``.

    Algo-specific behavior should live in the ./python/gen_{algo}.py customization
    files rather than here; only the Python-specific type translation (H2OFrame,
    enums rendered as lists, ...) belongs in this generator.

    The generated class accepts its parameters through ``__init__(**kwargs)`` and
    routes each one through a type-checking property setter.

    NOTE: custom getter/setter templates are rendered with ``.format(**locals())``,
    so the local variable names of this function are part of that template
    contract and must not be renamed or removed.
    """
    classname = algo_to_classname(algo)
    rest_api_version = get_customizations_for(algo, 'rest_api_version')
    extra_imports = get_customizations_for(algo, 'extensions.__imports__')
    class_doc = get_customizations_for(algo, 'doc.__class__')
    class_examples = get_customizations_for(algo, 'examples.__class__')
    class_init_validation = get_customizations_for(algo, 'extensions.__init__validation')
    class_init_setparams = get_customizations_for(algo, 'extensions.__init__setparams')
    class_extras = get_customizations_for(algo, 'extensions.__class__')
    module_extras = get_customizations_for(algo, 'extensions.__module__')
    update_param_defaults = get_customizations_for('defaults', 'update_param')
    update_param = get_customizations_for(algo, 'update_param')

    def extend_schema_params(param):
        # Try the algo-specific updater, then the default one; the first
        # non-None result replaces the deep-copied schema param.
        pname = param.get('name')
        param = deepcopy(param)
        updates = None
        for update_fn in (update_param, update_param_defaults):
            if callable(update_fn):
                updates = update_fn(pname, param)
            if updates is not None:
                return updates
        return param

    extended_params = [extend_schema_params(p) for p in schema['parameters']]

    param_names = []
    for param in extended_params:
        pname, ptype = param.get('name'), param.get('type')
        pvalues, pdefault = param.get('values'), param.get('default_value')

        assert (ptype[:4] == 'enum') == bool(pvalues), "Values are expected for enum types only"
        enum_values = [normalize_enum_constant(p) for p in pvalues] if pvalues else None
        if pvalues and pdefault:
            pdefault = normalize_enum_constant(pdefault)

        if pname in reserved_words:
            pname += "_"  # avoid clashing with Python keywords/builtins
        param_names.append(pname)
        param['pname'] = pname
        param['default_value'] = pdefault
        param['ptype'] = translate_type_for_check(ptype, enum_values)
        param['dtype'] = translate_type_for_doc(ptype, enum_values)

    header = (
        "#!/usr/bin/env python",
        "# -*- encoding: utf-8 -*-",
        "#",
        "# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_python.py",
        "# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details)",
        "#",
        "from __future__ import absolute_import, division, print_function, unicode_literals",
        "",
    )
    for line in header:
        yield line
    if extra_imports:
        yield reformat_block(extra_imports)
    for line in ("from h2o.estimators.estimator_base import H2OEstimator",
                 "from h2o.exceptions import H2OValueError",
                 "from h2o.frame import H2OFrame",
                 "from h2o.utils.typechecks import assert_is_type, Enum, numeric",
                 "",
                 ""):
        yield line

    # class header and docstring
    yield "class %s(H2OEstimator):" % classname
    yield '    """'
    yield "    " + schema["algo_full_name"]
    yield ""
    if class_doc:
        yield reformat_block(class_doc, 4)
    if class_examples:
        yield ""
        yield "    :examples:"
        yield ""
        yield reformat_block(class_examples, 4)
    yield '    """'
    yield ""
    yield '    algo = "%s"' % algo
    # 19 == len("    param_names = {"): wrapped names align after the brace
    yield "    param_names = {%s}" % bi.wrap(", ".join('"%s"' % p for p in param_names),
                                             indent=(" " * 19), indent_first=False)
    yield ""

    # __init__ dispatches every kwarg through the matching property setter
    yield "    def __init__(self, **kwargs):"
    # TODO: generate __init__ docstring with all params (also generate exact signature to support auto-completion)
    yield "        super(%s, self).__init__()" % classname
    yield "        self._parms = {}"
    if class_init_validation:
        yield reformat_block(class_init_validation, 8)
    yield "        for pname, pvalue in kwargs.items():"
    yield "            if pname == 'model_id':"
    yield "                self._id = pvalue"
    yield '                self._parms["model_id"] = pvalue'
    if class_init_setparams:
        yield reformat_block(class_init_setparams, 12)
    yield "            elif pname in self.param_names:"
    yield "                # Using setattr(...) will invoke type-checking of the arguments"
    yield "                setattr(self, pname, pvalue)"
    yield "            else:"
    yield '                raise H2OValueError("Unknown parameter %s = %r" % (pname, pvalue))'
    if rest_api_version:
        yield '        self._parms["_rest_version"] = %s' % rest_api_version
    yield ""

    # one property (getter + setter) per parameter
    for param in extended_params:
        pname = param.get('pname')
        if pname == "model_id":
            continue  # The getter is already defined in ModelBase

        sname = pname[:-1] if pname[-1] == '_' else pname  # server-side name without reserved-word suffix
        ptype = param.get('ptype')
        dtype = param.get('dtype')
        pdefault = param.get('default_value')

        if not dtype.startswith("Enum"):
            property_doc = "Type: ``%s``" % dtype
        else:
            vals = dtype[5:-1].split(", ")
            property_doc = "One of: " + ", ".join("``%s``" % v for v in vals)
        property_doc += "." if pdefault is None else " (default: ``%s``)." % stringify(pdefault)
        deprecated = pname in get_customizations_for(algo, 'deprecated', [])

        yield "    @property"
        yield "    def %s(self):" % pname
        yield '        """'
        # only the server-provided help text needs wrapping
        yield bi.wrap("%s%s" % ("[Deprecated] " if deprecated else "", param.get('help')), indent=8 * ' ')
        yield ""
        yield bi.wrap(property_doc, indent=8 * ' ')
        custom_property_doc = get_customizations_for(algo, "doc.{}".format(pname))
        if custom_property_doc:
            yield ""
            yield reformat_block(custom_property_doc, 8)
        property_examples = get_customizations_for(algo, "examples.{}".format(pname))
        if property_examples:
            yield ""
            yield "        :examples:"
            yield ""
            yield reformat_block(property_examples, 8)
        yield '        """'

        property_getter = get_customizations_for(algo, "overrides.{}.getter".format(pname))  # check gen_stackedensemble.py for an example
        if not property_getter:
            yield "        return self._parms.get(\"%s\")" % sname
        else:
            yield reformat_block(property_getter.format(**locals()), 8)
        yield ""

        yield "    @%s.setter" % pname
        yield "    def %s(self, %s):" % (pname, pname)
        property_setter = get_customizations_for(algo, "overrides.{}.setter".format(pname))  # check gen_stackedensemble.py for an example
        if property_setter:
            yield reformat_block(property_setter.format(**locals()), 8)
        elif ptype == "H2OFrame":
            # H2OFrame._validate both checks and converts, so it assigns directly
            yield "        self._parms[\"%s\"] = H2OFrame._validate(%s, '%s')" % (sname, pname, pname)
        else:
            if ptype == "H2OEstimator":
                # estimators may also be referenced by their string id
                yield "        assert_is_type(%s, None, str, %s)" % (pname, ptype)
            else:
                yield "        assert_is_type(%s, None, %s)" % (pname, ptype)
            yield "        self._parms[\"%s\"] = %s" % (sname, pname)
        yield ""

    yield ""
    if class_extras:
        yield reformat_block(code_as_str(class_extras), 4)
    if module_extras:
        yield ""
        yield reformat_block(code_as_str(module_extras))