Пример #1
0
    def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
        """ Visualize a cross-categorization state.

        Rebuilds a ``BCState`` from the stored model ``self._models[state_idx]``,
        calls its ``get_logps()`` and passes the data, model, log
        probabilities and the dataframe's index/columns to
        ``pu.plot_cc_model``.

        Parameters
        ----------
        state_idx
            Selects the model in ``self._models`` to draw.
        hl_rows : label, list of labels or numpy.ndarray of labels, optional
            Row label(s) to highlight. Anything that is not a list or an
            ndarray is treated as a single label. Labels are translated with
            ``self._converters['row2idx']``. The default, an empty tuple,
            means "highlight nothing" and is passed through untouched.
        hl_cols : label, list of labels or numpy.ndarray of labels, optional
            Same as `hl_rows`, translated with
            ``self._converters['col2idx']``.

        Raises
        ------
        NotImplementedError
            If ``self._subsampled`` is truthy.

        .. note :: Work in progress
            This function is currently only suited to small data tables. There
            are problems with the row labels overlapping. There are problems
            with singleton views and categories having negative size or
            appearing as lines. Lots to fix.
        """
        if self._subsampled:
            raise NotImplementedError("Not implemented for sub-sampled states")

        def _labels_to_indices(labels, converter_key):
            # The empty tuple is the "nothing to highlight" sentinel. It is
            # detected by type and length instead of ``labels != ()`` because
            # comparing a numpy array with a tuple is an element-wise
            # operation whose result is not a usable ``if`` condition.
            if isinstance(labels, tuple) and len(labels) == 0:
                return labels
            if not isinstance(labels, (list, np.ndarray)):
                labels = [labels]
            converter = self._converters[converter_key]
            return [converter[label] for label in labels]

        hl_rows = _labels_to_indices(hl_rows, 'row2idx')
        hl_cols = _labels_to_indices(hl_cols, 'col2idx')

        model = self._models[state_idx]
        init_kwargs = {'dtypes': self._dtypes,
                       'distargs': self._distargs,
                       'Zv': model['col_assignment'],
                       'Zrcv': model['row_assignments'],
                       'col_hypers': model['col_hypers'],
                       'state_alpha': model['state_alpha'],
                       'view_alphas': model['view_alphas']}
        # BCState receives the transposed data table
        state = BCState(self._data.T, **init_kwargs)
        model_logps = state.get_logps()
        pu.plot_cc_model(self._data, model, model_logps, self._df.index,
                         self._df.columns, hl_rows=hl_rows, hl_cols=hl_cols)
Пример #2
0
def _initialize(args):
    """Build a new state and return its metadata with an initial diagnostic.

    ``args`` is read positionally: ``args[0]`` is the data table (its ``.T``
    is handed to ``BCState``) and ``args[1]`` is a mapping of keyword
    arguments for the ``BCState`` constructor.

    Returns a ``(metadata, diagnostics)`` pair. ``metadata`` is the result of
    ``state.get_metadata()``; ``diagnostics`` is a one-element list whose
    dict holds the state's log score, an iteration count of zero and the
    wall-clock seconds elapsed since just before the state was constructed.
    """
    data, kwargs = args[0], args[1]

    started = time.time()
    state = BCState(data.T, **kwargs)  # transposed: data goes in col-major
    metadata = state.get_metadata()

    log_score = state.log_score()
    elapsed = time.time() - started

    return metadata, [{'log_score': log_score, 'iters': 0, 'time': elapsed}]
Пример #3
0
def _initialize(args):
    """Construct a state from packed arguments and describe it.

    Parameters
    ----------
    args : sequence
        ``args[0]`` is the data table, which is transposed before being
        given to ``BCState``; ``args[1]`` holds the keyword arguments for
        the ``BCState`` constructor.

    Returns
    -------
    tuple
        ``(metadata, [diagnostics])`` where ``metadata`` comes from
        ``state.get_metadata()`` and ``diagnostics`` is a dict with the keys
        ``'log_score'``, ``'iters'`` (always 0) and ``'time'`` (seconds
        elapsed since just before construction).
    """
    table = args[0]
    state_kwargs = args[1]

    clock_start = time.time()
    # transpose data to col-major before building the state
    state = BCState(table.T, **state_kwargs)
    metadata = state.get_metadata()

    diagnostics = {}
    diagnostics['log_score'] = state.log_score()
    diagnostics['iters'] = 0
    diagnostics['time'] = time.time() - clock_start

    return metadata, [diagnostics]
Пример #4
0
def blank():
    """Return a ``BCState`` built from a small array of random values.

    The array has shape ``(n_cols, n_rows) = (4, 10)`` and is filled by
    ``np.random.rand``; no other arguments are given to ``BCState``.
    """
    # column major: the first axis indexes columns, the second indexes rows
    shape = (4, 10)
    return BCState(np.random.rand(*shape))
Пример #5
0
    def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
        """ Visualize a cross-categorization state.

        Rebuilds a ``BCState`` from the stored model ``self._models[state_idx]``,
        calls its ``get_logps()`` and passes the data, model, log
        probabilities and the dataframe's index/columns to
        ``pu.plot_cc_model``.

        Parameters
        ----------
        state_idx
            Selects the model in ``self._models`` to draw.
        hl_rows : label, list of labels or numpy.ndarray of labels, optional
            Row label(s) to highlight. Anything that is not a list or an
            ndarray is treated as a single label. Labels are translated with
            ``self._converters['row2idx']``. The default, an empty tuple,
            means "highlight nothing" and is passed through untouched.
        hl_cols : label, list of labels or numpy.ndarray of labels, optional
            Same as `hl_rows`, translated with
            ``self._converters['col2idx']``.

        Raises
        ------
        NotImplementedError
            If ``self._subsampled`` is truthy.

        .. note :: Work in progress
            This function is currently only suited to small data tables. There
            are problems with the row labels overlapping. There are problems
            with singleton views and categories having negative size or
            appearing as lines. Lots to fix.
        """
        if self._subsampled:
            raise NotImplementedError("Not implemented for sub-sampled states")

        def _labels_to_indices(labels, converter_key):
            # The empty tuple is the "nothing to highlight" sentinel. It is
            # detected by type and length instead of ``labels != ()`` because
            # comparing a numpy array with a tuple is an element-wise
            # operation whose result is not a usable ``if`` condition.
            if isinstance(labels, tuple) and len(labels) == 0:
                return labels
            if not isinstance(labels, (
                    list,
                    np.ndarray,
            )):
                labels = [labels]
            converter = self._converters[converter_key]
            return [converter[label] for label in labels]

        hl_rows = _labels_to_indices(hl_rows, 'row2idx')
        hl_cols = _labels_to_indices(hl_cols, 'col2idx')

        model = self._models[state_idx]
        init_kwargs = {
            'dtypes': self._dtypes,
            'distargs': self._distargs,
            'Zv': model['col_assignment'],
            'Zrcv': model['row_assignments'],
            'col_hypers': model['col_hypers'],
            'state_alpha': model['state_alpha'],
            'view_alphas': model['view_alphas']
        }
        # BCState receives the transposed data table
        state = BCState(self._data.T, **init_kwargs)
        model_logps = state.get_logps()
        pu.plot_cc_model(self._data,
                         model,
                         model_logps,
                         self._df.index,
                         self._df.columns,
                         hl_rows=hl_rows,
                         hl_cols=hl_cols)
Пример #6
0
def _run(args):
    """Run transitions on one state, recording a diagnostic after each sweep.

    Parameters
    ----------
    args : sequence
        Read positionally as
        ``(data, checkpoint, t_id, verbose, init_kwargs, trans_kwargs)``:

        - ``data``: data table; its ``.T`` is handed to ``BCState``.
        - ``checkpoint``: number of iterations per sweep, or ``None`` to run
          all ``N`` iterations as a single sweep.
        - ``t_id``: identifier printed in the verbose message.
        - ``verbose``: when truthy, a progress message is printed per sweep.
        - ``init_kwargs``: keyword arguments for the ``BCState`` constructor.
        - ``trans_kwargs``: keyword arguments for ``state.transition``; must
          contain ``'N'``, the total number of iterations. The caller's dict
          is copied and never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)``. ``metadata`` is the result of
        ``state.get_metadata()`` after the last sweep. ``diagnostics`` is a
        list with one dict per sweep holding ``'log_score'``, ``'n_views'``,
        ``'iters'`` (iterations run in that sweep) and ``'time'`` (seconds).

    Raises
    ------
    ValueError
        If `checkpoint` is given but is not positive.

    Notes
    -----
    When ``N`` is not a multiple of `checkpoint`, the leftover iterations are
    run as one final, shorter sweep so that exactly ``N`` iterations are
    performed in total.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]
    trans_kwargs = args[5]

    # create copy of trans_kwargs so we don't mutate
    trans_kwargs = dict(trans_kwargs)
    n_iter = trans_kwargs['N']
    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint <= 0:
            raise ValueError(
                "checkpoint must be a positive number of iterations")
        n_full, n_left = divmod(n_iter, checkpoint)
        sweep_sizes = [checkpoint] * int(n_full)
        # Truncating the division would silently drop these iterations (or
        # run nothing at all when checkpoint > N).
        if n_left:
            sweep_sizes.append(n_left)
    n_sweeps = len(sweep_sizes)

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major
    for sweep, sweep_iters in enumerate(sweep_sizes, start=1):
        trans_kwargs['N'] = sweep_iters

        t_start = time.time()
        state.transition(**trans_kwargs)
        t_iter = time.time() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostic = {
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter}
        diagnostics.append(diagnostic)

        if verbose:
            msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
            msg += "\n\t+ log score: {}"
            msg += "\n\t+ n_views: {}\n"

            # sweep is 1-based so the final message reads "sweep N of N"
            print(msg.format(t_id, sweep, n_sweeps, t_iter, log_score,
                             n_views))

    metadata = state.get_metadata()

    return metadata, diagnostics
Пример #7
0
 def _init_bcstates(self):
     """Rebuild ``self._bcstates`` with one ``BCState`` per stored model.

     Each state is constructed from the transposed data, the engine's
     dtypes and distargs, the model's assignments, hyperparameters and
     alphas, and a seed drawn with ``np.random.randint(2**31 - 1)``.
     """
     self._bcstates = []
     for stored in self._models:
         # the seed is drawn before the model fields are read
         seed = np.random.randint(2**31 - 1)
         state_kwargs = {
             'dtypes': self._dtypes,
             'distargs': self._distargs,
             'Zv': stored['col_assignment'],
             'Zrcv': stored['row_assignments'],
             'col_hypers': stored['col_hypers'],
             'state_alpha': stored['state_alpha'],
             'view_alphas': stored['view_alphas'],
             'seed': seed,
         }
         bcstate = BCState(self._data.T, **state_kwargs)
         self._bcstates.append(bcstate)
Пример #8
0
def _run(args):
    """Run transitions on one state, recording a diagnostic after each sweep.

    Parameters
    ----------
    args : sequence
        Read positionally as
        ``(data, checkpoint, t_id, verbose, init_kwargs, trans_kwargs)``:

        - ``data``: data table; its ``.T`` is handed to ``BCState``.
        - ``checkpoint``: number of iterations per sweep, or ``None`` to run
          all ``N`` iterations as a single sweep.
        - ``t_id``: identifier printed in the verbose message.
        - ``verbose``: when truthy, a progress message is printed per sweep.
        - ``init_kwargs``: keyword arguments for the ``BCState`` constructor.
        - ``trans_kwargs``: keyword arguments for ``state.transition``; must
          contain ``'N'``, the total number of iterations. The caller's dict
          is copied and never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)``. ``metadata`` is the result of
        ``state.get_metadata()`` after the last sweep. ``diagnostics`` is a
        list with one dict per sweep holding ``'log_score'``, ``'n_views'``,
        ``'iters'`` (iterations run in that sweep) and ``'time'`` (seconds).

    Raises
    ------
    ValueError
        If `checkpoint` is given but is not positive.

    Notes
    -----
    When ``N`` is not a multiple of `checkpoint`, the leftover iterations are
    run as one final, shorter sweep so that exactly ``N`` iterations are
    performed in total.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]
    trans_kwargs = args[5]

    # create copy of trans_kwargs so we don't mutate
    trans_kwargs = dict(trans_kwargs)
    n_iter = trans_kwargs['N']
    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint <= 0:
            raise ValueError(
                "checkpoint must be a positive number of iterations")
        n_full, n_left = divmod(n_iter, checkpoint)
        sweep_sizes = [checkpoint] * int(n_full)
        # Truncating the division would silently drop these iterations (or
        # run nothing at all when checkpoint > N).
        if n_left:
            sweep_sizes.append(n_left)
    n_sweeps = len(sweep_sizes)

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major
    for sweep, sweep_iters in enumerate(sweep_sizes, start=1):
        trans_kwargs['N'] = sweep_iters

        t_start = time.time()
        state.transition(**trans_kwargs)
        t_iter = time.time() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostic = {
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter
        }
        diagnostics.append(diagnostic)

        if verbose:
            msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
            msg += "\n\t+ log score: {}"
            msg += "\n\t+ n_views: {}\n"

            # sweep is 1-based so the final message reads "sweep N of N"
            print(
                msg.format(t_id, sweep, n_sweeps, t_iter, log_score, n_views))

    metadata = state.get_metadata()

    return metadata, diagnostics
Пример #9
0
def test_transition(gdat):
    """Smoke test: ``transition()`` with no arguments runs without raising."""
    data, _, _ = gdat
    # only the first element of the fixture is used here
    BCState(data).transition()
Пример #10
0
def test_predictive_logp(gdat):
    """``predictive_logp`` returns a value in both call forms.

    The four-argument and the two-argument call are checked separately so
    that the first result is not overwritten before it has been verified.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    # the two trailing arguments are presumably conditioning cells and
    # their values -- TODO confirm against BCState.predictive_logp
    res_four_args = bcstate.predictive_logp([[2, 2]], [1.2], [[1, 3]], [0.0])
    assert res_four_args is not None

    res_two_args = bcstate.predictive_logp([[2, 2]], [1.2])
    assert res_two_args is not None
Пример #11
0
def test_predictive_draw(gdat):
    """``predictive_draw`` returns a value in both call forms.

    The three-argument and the one-argument call are checked separately so
    that the first result is not overwritten before it has been verified.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    # the two trailing arguments are presumably conditioning cells and
    # their values -- TODO confirm against BCState.predictive_draw
    res_three_args = bcstate.predictive_draw([[2, 2]], [[1, 3]], [4.1])
    assert res_three_args is not None

    res_one_arg = bcstate.predictive_draw([[2, 2]])
    assert res_one_arg is not None
Пример #12
0
def test_get_metadata(gdat):
    """``get_metadata`` on a freshly built state does not return ``None``."""
    data, _, _ = gdat
    metadata = BCState(data).get_metadata()
    assert metadata is not None
Пример #13
0
def test_transition(gdat):
    """Check that a default ``transition()`` call completes without error."""
    table, _, _ = gdat
    state = BCState(table)  # built from the fixture's first element only
    state.transition()
Пример #14
0
def test_initalizes_no_args(gdat):
    """``BCState`` can be constructed from the data alone."""
    data, _, _ = gdat
    # no dtypes or distargs are supplied; only the data array
    assert BCState(data) is not None
Пример #15
0
def test_initalizes_base(gdat):
    """``BCState`` can be constructed from all three fixture elements.

    The fixture's data, datatype list and distargs are passed positionally,
    in that order.
    """
    data, dtypes, dist_args = gdat
    state = BCState(data, dtypes, dist_args)
    assert state is not None
Пример #16
0
def test_get_metadata(gdat):
    """A new state's ``get_metadata()`` yields something other than ``None``."""
    table, _, _ = gdat
    state = BCState(table)
    assert state.get_metadata() is not None
Пример #17
0
def test_predictive_draw(gdat):
    """``predictive_draw`` returns a value in both call forms.

    The three-argument and the one-argument call are checked separately so
    that the first result is not overwritten before it has been verified.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    # the two trailing arguments are presumably conditioning cells and
    # their values -- TODO confirm against BCState.predictive_draw
    res_three_args = bcstate.predictive_draw([[2, 2]], [[1, 3]], [4.1])
    assert res_three_args is not None

    res_one_arg = bcstate.predictive_draw([[2, 2]])
    assert res_one_arg is not None
Пример #18
0
def test_predictive_probability(gdat):
    """``predictive_probability`` returns a value in both call forms.

    The four-argument and the two-argument call are checked separately so
    that the first result is not overwritten before it has been verified.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    # the two trailing arguments are presumably conditioning cells and
    # their values -- TODO confirm against BCState.predictive_probability
    res_four_args = bcstate.predictive_probability([[2, 2]], [1.2],
                                                   [[1, 3]], [0.0])
    assert res_four_args is not None

    res_two_args = bcstate.predictive_probability([[2, 2]], [1.2])
    assert res_two_args is not None
Пример #19
0
def test_conditioned_row_resample(gdat):
    """Smoke test: ``conditioned_row_resample`` runs without raising.

    It is called with ``2`` and a callable that returns ``0.0`` for any
    input.
    """
    def always_zero(x):
        return 0.0

    data, _, _ = gdat
    BCState(data).conditioned_row_resample(2, always_zero)