예제 #1
0
    def __init__(self, defn, views, latent, kernel_config):
        """Validate the runner inputs and normalize ``kernel_config``.

        Parameters
        ----------
        defn : model_definition
            Model definition; its ``relations()`` and ``domains()`` fix the
            valid relation/domain indices used in kernel configs.
        views : sequence of abstract_dataview
            Exactly one view per entry of ``defn.relations()``.
        latent : state
            Starting state. A deep copy is stored on the runner.
        kernel_config : iterable of (name, config) pairs
            ``config`` is either dict-like or, as a shorthand, an iterable
            of keys that is expanded to ``{key: {} for key in config}``.

        Recognized kernels and their ``config`` shapes:

        * ``'assign'``: ``{domain index: {}}``
        * ``'assign_resample'``: ``{domain index: {'m': positive value}}``
        * ``'slice_cluster_hp'``: ``{domain index: {'cparam': ...}}``
        * ``'grid_relation_hp'``:
          ``{relation index: {'hpdf': ..., 'hgrid': [partial hp dicts]}}``;
          each partial dict is overlaid on ``latent.get_relation_hp(index)``.
          The stored config holds the expanded grid; the caller's config is
          not modified.
        * ``'slice_relation_hp'``: ``{'hparams': {relation index: ...}}``
        * ``'theta'``: ``{'tparams': {relation index: ...}}``

        Raises
        ------
        ValueError
            On an unknown kernel name, out-of-range indices, or unexpected
            config keys. The ``validator`` helpers may raise their own
            errors for type/length/positivity violations.
        """
        validator.validate_type(defn, model_definition, 'defn')
        validator.validate_len(views, len(defn.relations()), 'views')
        for view in views:
            validator.validate_type(view, abstract_dataview, 'views')
        validator.validate_type(latent, state, 'latent')

        self._defn = defn
        self._views = views
        self._latent = copy.deepcopy(latent)

        # The valid index sets depend only on ``defn``: compute them once
        # instead of rebuilding helper closures for every kernel.
        relation_keys = set(xrange(len(defn.relations())))
        domain_keys = set(xrange(len(defn.domains())))

        def require_relation_keys(config):
            if not set(config.keys()).issubset(relation_keys):
                raise ValueError("bad config found: {}".format(config))

        def require_domain_keys(config):
            if not set(config.keys()).issubset(domain_keys):
                raise ValueError("bad config found: {}".format(config))

        self._kernel_config = []
        for kernel in kernel_config:
            name, config = kernel

            # Shorthand: a non-mapping iterable of indices means "empty
            # parameters for each of these indices".
            if not hasattr(config, 'iteritems'):
                config = {c: {} for c in config}
            validator.validate_dict_like(config)

            # NOTE: the ``x.keys() != [...]`` checks below rely on Python 2
            # ``dict.keys()`` returning a list.
            if name == 'assign':
                require_domain_keys(config)
                for v in config.values():
                    validator.validate_dict_like(v)
                    if v:
                        msg = "assign has no config params: {}".format(v)
                        raise ValueError(msg)

            elif name == 'assign_resample':
                require_domain_keys(config)
                for v in config.values():
                    validator.validate_dict_like(v)
                    if v.keys() != ['m']:
                        raise ValueError("bad config found: {}".format(v))
                    validator.validate_positive(v['m'])

            elif name == 'slice_cluster_hp':
                require_domain_keys(config)
                for v in config.values():
                    validator.validate_dict_like(v)
                    if v.keys() != ['cparam']:
                        raise ValueError("bad config found: {}".format(v))

            elif name == 'grid_relation_hp':
                require_relation_keys(config)
                expanded = {}
                for ri, ps in config.iteritems():
                    validator.validate_dict_like(ps)
                    if set(ps.keys()) != set(('hpdf', 'hgrid',)):
                        raise ValueError("bad config found: {}".format(ps))
                    full = []
                    for partial in ps['hgrid']:
                        # Start from the relation's current hyperparameters
                        # and overlay the (possibly partial) grid point.
                        hp = latent.get_relation_hp(ri)
                        hp.update(partial)
                        full.append(hp)
                    # Build a new dict rather than writing into ``ps``: ``ps``
                    # belongs to the caller's kernel_config.
                    expanded[ri] = {'hpdf': ps['hpdf'], 'hgrid': full}
                config = expanded

            elif name == 'slice_relation_hp':
                if config.keys() != ['hparams']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_dict_like(config['hparams'])
                require_relation_keys(config['hparams'])

            elif name == 'theta':
                if config.keys() != ['tparams']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_dict_like(config['tparams'])
                require_relation_keys(config['tparams'])

            else:
                raise ValueError("bad kernel found: {}".format(name))

            self._kernel_config.append((name, config))
예제 #2
0
def test_validate_len():
    """validate_len accepts a matching length and rejects a mismatched one."""
    single = [1]
    expected = len(single)
    V.validate_len(single, expected)
    assert_raises(ValueError, V.validate_len, single, expected + 1)
예제 #3
0
    def __init__(self, defn, view, latent, kernel_config):
        """Validate the runner inputs and normalize ``kernel_config``.

        Parameters
        ----------
        defn
            Model definition; it is passed through ``_validate_definition``
            and the returned value is the one used afterwards.
        view : abstract_dataview
            Data view whose length must equal ``defn.n()``.
        latent : state
            Starting state. A deep copy is stored on the runner.
        kernel_config : iterable
            Each item is either a bare kernel name (an object without
            ``__iter__``, using an empty config) or a ``(name, config)``
            pair with a dict-like ``config``.

        Recognized kernels and their ``config`` shapes (feature indices
        range over ``xrange(len(defn.models()))``):

        * ``'assign'``: ``{}`` (no parameters)
        * ``'assign_resample'``: ``{'m': positive value}``
        * ``'grid_feature_hp'``:
          ``{feature index: {'hpdf': ..., 'hgrid': [partial hp dicts]}}``;
          each partial dict is overlaid on ``latent.get_feature_hp(index)``.
          The stored config holds the expanded grid; the caller's config is
          not modified.
        * ``'slice_feature_hp'``: ``{'hparams': {feature index: ...}}``
        * ``'slice_cluster_hp'``: ``{'cparam': {'alpha': ...}}``
        * ``'theta'``: ``{'tparams': {feature index: ...}}``

        Raises
        ------
        ValueError
            If ``latent`` is not a ``state``, on an unknown kernel name,
            out-of-range feature indices, or unexpected config keys. The
            ``validator`` helpers may raise their own errors as well.
        """
        defn = _validate_definition(defn)
        validator.validate_type(view, abstract_dataview, param_name='view')
        if not isinstance(latent, state):
            raise ValueError("bad latent given")
        validator.validate_len(view, defn.n(), 'view')

        def require_feature_indices(v):
            nfeatures = len(defn.models())
            valid_keys = set(xrange(nfeatures))
            if not set(v.keys()).issubset(valid_keys):
                msg = "bad config found: {}".format(v)
                raise ValueError(msg)

        self._defn = defn
        self._view = view
        self._latent = copy.deepcopy(latent)

        self._kernel_config = []
        for kernel in kernel_config:

            # In Python 2, ``str`` has no ``__iter__``, so a bare kernel name
            # takes the else branch while (name, config) pairs unpack here.
            if hasattr(kernel, '__iter__'):
                name, config = kernel
            else:
                name, config = kernel, {}
            validator.validate_dict_like(config)

            # NOTE: the ``x.keys() != [...]`` checks below rely on Python 2
            # ``dict.keys()`` returning a list.
            if name == 'assign':
                if config:
                    raise ValueError("assign has no parameters")

            elif name == 'assign_resample':
                if config.keys() != ['m']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_positive(config['m'])

            elif name == 'grid_feature_hp':
                require_feature_indices(config)
                expanded = {}
                for fi, ps in config.iteritems():
                    validator.validate_dict_like(ps)
                    if set(ps.keys()) != set(('hpdf', 'hgrid',)):
                        raise ValueError("bad config found: {}".format(ps))
                    full = []
                    for partial in ps['hgrid']:
                        # Start from the feature's current hyperparameters
                        # and overlay the (possibly partial) grid point.
                        hp = latent.get_feature_hp(fi)
                        hp.update(partial)
                        full.append(hp)
                    # Build a new dict rather than writing into ``ps``: ``ps``
                    # belongs to the caller's kernel_config.
                    expanded[fi] = {'hpdf': ps['hpdf'], 'hgrid': full}
                config = expanded

            elif name == 'slice_feature_hp':
                if config.keys() != ['hparams']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_dict_like(config['hparams'])
                require_feature_indices(config['hparams'])

            elif name == 'slice_cluster_hp':
                if config.keys() != ['cparam']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_dict_like(config['cparam'])
                if config['cparam'].keys() != ['alpha']:
                    msg = "bad config found: {}".format(config['cparam'])
                    raise ValueError(msg)

            elif name == 'theta':
                if config.keys() != ['tparams']:
                    raise ValueError("bad config found: {}".format(config))
                validator.validate_dict_like(config['tparams'])
                require_feature_indices(config['tparams'])

            else:
                raise ValueError("bad kernel found: {}".format(name))

            self._kernel_config.append((name, config))
예제 #4
0
def test_validate_len():
    """Check ``V.validate_len`` against a one-element list."""
    sample = [1]
    # Correct length: expected to pass without raising.
    V.validate_len(sample, 1)
    # Wrong length: expected to raise ValueError.
    bad_length = 2
    assert_raises(ValueError, V.validate_len, sample, bad_length)