Example #1
0
    def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
        """ Visualize a cross-categorization state.

        .. note :: Work in progress
            This function is currently only suited to small data tables. There
            are problems with the row labels overlapping. There are problems
            with singleton views and categories having negative size or
            appearing as lines. Lots to fix.

        Parameters
        ----------
        state_idx : index
            Position in ``self._models`` of the model to plot (presumably an
            int -- TODO confirm).
        hl_rows : row label, or list / ndarray of row labels, optional
            Rows to highlight. Labels are translated to indices through
            ``self._converters['row2idx']``. The default ``()`` highlights
            nothing.
        hl_cols : column label, or list / ndarray of column labels, optional
            Columns to highlight, translated through
            ``self._converters['col2idx']``. The default ``()`` highlights
            nothing.

        Raises
        ------
        NotImplementedError
            If the stored states are sub-sampled.
        """
        if self._subsampled:
            raise NotImplementedError("Not implemented for sub-sampled states")

        def _label_indices(labels, converter_key):
            # ``()`` is the "highlight nothing" default and is passed through
            # unchanged. It is detected by type/length instead of
            # ``labels != ()``: for an ndarray that is an elementwise
            # comparison whose truth value is ambiguous (or which raises on a
            # shape mismatch in recent NumPy).
            if isinstance(labels, tuple) and len(labels) == 0:
                return labels
            if not isinstance(labels, (list, np.ndarray)):
                labels = [labels]  # a single label
            converter = self._converters[converter_key]
            return [converter[label] for label in labels]

        hl_rows = _label_indices(hl_rows, 'row2idx')
        hl_cols = _label_indices(hl_cols, 'col2idx')

        model = self._models[state_idx]
        init_kwargs = {'dtypes': self._dtypes,
                       'distargs': self._distargs,
                       'Zv': model['col_assignment'],
                       'Zrcv': model['row_assignments'],
                       'col_hypers': model['col_hypers'],
                       'state_alpha': model['state_alpha'],
                       'view_alphas': model['view_alphas']}
        # BCState is built from the transposed data table.
        state = BCState(self._data.T, **init_kwargs)
        model_logps = state.get_logps()
        pu.plot_cc_model(self._data, model, model_logps, self._df.index,
                         self._df.columns, hl_rows=hl_rows, hl_cols=hl_cols)
Example #2
0
def _initialize(args):
    """Build a BCState from packed arguments and report initial diagnostics.

    Parameters
    ----------
    args : sequence
        ``args[0]`` is the data table. It is transposed before being handed
        to ``BCState``. ``args[1]`` is a dict of keyword arguments forwarded
        to ``BCState``. (The single packed argument presumably suits
        ``map``-style parallel dispatch -- TODO confirm with callers.)

    Returns
    -------
    tuple
        ``(metadata, [diagnostics])``. ``metadata`` is
        ``state.get_metadata()``. ``diagnostics`` is a dict holding:

        - ``'log_score'``: the initial ``state.log_score()``.
        - ``'iters'``: always 0.
        - ``'time'``: seconds spent constructing the state and computing its
          metadata and log score.
    """
    data = args[0]
    kwargs = args[1]

    # perf_counter is monotonic; time.time() follows the wall clock and can
    # jump (e.g. NTP adjustments), producing wrong or negative durations.
    t_start = time.perf_counter()
    state = BCState(data.T, **kwargs)  # transpose data to col-major

    metadata = state.get_metadata()
    log_score = state.log_score()
    diagnostics = {
        'log_score': log_score,
        'iters': 0,
        'time': time.perf_counter() - t_start}

    return metadata, [diagnostics]
Example #3
0
def _initialize(args):
    """Build a BCState from packed arguments and report initial diagnostics.

    Parameters
    ----------
    args : sequence
        ``args[0]`` is the data table. It is transposed before being handed
        to ``BCState``. ``args[1]`` is a dict of keyword arguments forwarded
        to ``BCState``. (The single packed argument presumably suits
        ``map``-style parallel dispatch -- TODO confirm with callers.)

    Returns
    -------
    tuple
        ``(metadata, [diagnostics])``. ``metadata`` is
        ``state.get_metadata()``. ``diagnostics`` is a dict holding:

        - ``'log_score'``: the initial ``state.log_score()``.
        - ``'iters'``: always 0.
        - ``'time'``: seconds spent constructing the state and computing its
          metadata and log score.
    """
    data = args[0]
    kwargs = args[1]

    # perf_counter is monotonic; time.time() follows the wall clock and can
    # jump (e.g. NTP adjustments), producing wrong or negative durations.
    t_start = time.perf_counter()
    state = BCState(data.T, **kwargs)  # transpose data to col-major

    metadata = state.get_metadata()
    log_score = state.log_score()
    diagnostics = {
        'log_score': log_score,
        'iters': 0,
        'time': time.perf_counter() - t_start
    }

    return metadata, [diagnostics]
Example #4
0
def blank():
    """Return a BCState built over a small random 10 x 4 table.

    The array given to ``BCState`` is column-major: shape
    ``(n_cols, n_rows)``, filled with ``np.random.rand`` values.
    """
    n_rows, n_cols = 10, 4
    col_major_data = np.random.rand(n_cols, n_rows)
    return BCState(col_major_data)
Example #5
0
    def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
        """ Visualize a cross-categorization state.

        .. note :: Work in progress
            This function is currently only suited to small data tables. There
            are problems with the row labels overlapping. There are problems
            with singleton views and categories having negative size or
            appearing as lines. Lots to fix.

        Parameters
        ----------
        state_idx : index
            Position in ``self._models`` of the model to plot (presumably an
            int -- TODO confirm).
        hl_rows : row label, or list / ndarray of row labels, optional
            Rows to highlight. Labels are translated to indices through
            ``self._converters['row2idx']``. The default ``()`` highlights
            nothing.
        hl_cols : column label, or list / ndarray of column labels, optional
            Columns to highlight, translated through
            ``self._converters['col2idx']``. The default ``()`` highlights
            nothing.

        Raises
        ------
        NotImplementedError
            If the stored states are sub-sampled.
        """
        if self._subsampled:
            raise NotImplementedError("Not implemented for sub-sampled states")

        def _label_indices(labels, converter_key):
            # ``()`` is the "highlight nothing" default and is passed through
            # unchanged. It is detected by type/length instead of
            # ``labels != ()``: for an ndarray that is an elementwise
            # comparison whose truth value is ambiguous (or which raises on a
            # shape mismatch in recent NumPy).
            if isinstance(labels, tuple) and len(labels) == 0:
                return labels
            if not isinstance(labels, (
                    list,
                    np.ndarray,
            )):
                labels = [labels]  # a single label
            converter = self._converters[converter_key]
            return [converter[label] for label in labels]

        hl_rows = _label_indices(hl_rows, 'row2idx')
        hl_cols = _label_indices(hl_cols, 'col2idx')

        model = self._models[state_idx]
        init_kwargs = {
            'dtypes': self._dtypes,
            'distargs': self._distargs,
            'Zv': model['col_assignment'],
            'Zrcv': model['row_assignments'],
            'col_hypers': model['col_hypers'],
            'state_alpha': model['state_alpha'],
            'view_alphas': model['view_alphas']
        }
        # BCState is built from the transposed data table.
        state = BCState(self._data.T, **init_kwargs)
        model_logps = state.get_logps()
        pu.plot_cc_model(self._data,
                         model,
                         model_logps,
                         self._df.index,
                         self._df.columns,
                         hl_rows=hl_rows,
                         hl_cols=hl_cols)
Example #6
0
def _run(args):
    """Run transitions on a freshly built BCState in checkpointed sweeps.

    Parameters
    ----------
    args : sequence
        Packed arguments, read by position:

        0. ``data`` -- data table; transposed before being given to
           ``BCState``.
        1. ``checkpoint`` -- iterations per sweep, or ``None`` to run all
           ``trans_kwargs['N']`` iterations in a single sweep.
        2. ``t_id`` -- identifier shown in verbose output.
        3. ``verbose`` -- if truthy, print progress after every sweep.
        4. ``init_kwargs`` -- keyword arguments for ``BCState``.
        5. ``trans_kwargs`` -- keyword arguments for ``state.transition``.
           Must contain ``'N'``, the total iteration count. The caller's
           dict is copied, never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)``. ``metadata`` is ``state.get_metadata()``
        after the last sweep. ``diagnostics`` holds one dict per sweep with
        keys ``'log_score'``, ``'n_views'``, ``'iters'`` (iterations run in
        that sweep) and ``'time'`` (seconds spent in ``transition``).

    Raises
    ------
    ValueError
        If ``checkpoint`` is given but is not positive.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]
    trans_kwargs = args[5]

    # create copy of trans_kwargs so we don't mutate
    trans_kwargs = dict(trans_kwargs)
    n_iter = trans_kwargs['N']
    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint <= 0:
            raise ValueError(
                "checkpoint must be positive, got {}".format(checkpoint))
        # Full checkpoint-sized sweeps plus one shorter sweep for the
        # remainder, so exactly N iterations run. Plain int(N / checkpoint)
        # dropped the remainder, and ran nothing when checkpoint > N.
        n_full, remainder = divmod(n_iter, checkpoint)
        sweep_sizes = [checkpoint] * int(n_full)
        if remainder:
            sweep_sizes.append(remainder)
    n_sweeps = len(sweep_sizes)

    msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
    msg += "\n\t+ log score: {}"
    msg += "\n\t+ n_views: {}\n"

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major
    for i, sweep_iters in enumerate(sweep_sizes):
        trans_kwargs['N'] = sweep_iters
        # perf_counter is monotonic, unlike the wall clock time.time().
        t_start = time.perf_counter()
        state.transition(**trans_kwargs)
        t_iter = time.perf_counter() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostic = {
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter}
        diagnostics.append(diagnostic)

        if verbose:
            # 1-based so the final sweep reads "sweep n of n"
            print(msg.format(t_id, i + 1, n_sweeps, t_iter, log_score,
                             n_views))

    metadata = state.get_metadata()

    return metadata, diagnostics
Example #7
0
 def _init_bcstates(self):
     """Rebuild ``self._bcstates``: one BCState per entry of ``self._models``.

     Every state receives the shared dtypes/distargs plus its model's
     assignments and hyperparameters. It also gets its own seed, drawn from
     NumPy's global RNG before that model's fields are read.
     """
     self._bcstates = []
     add_state = self._bcstates.append
     for model in self._models:
         seed = np.random.randint(2**31 - 1)
         kwargs = dict(
             dtypes=self._dtypes,
             distargs=self._distargs,
             Zv=model['col_assignment'],
             Zrcv=model['row_assignments'],
             col_hypers=model['col_hypers'],
             state_alpha=model['state_alpha'],
             view_alphas=model['view_alphas'],
             seed=seed,
         )
         # BCState is built from the transposed data table.
         add_state(BCState(self._data.T, **kwargs))
Example #8
0
def _run(args):
    """Run transitions on a freshly built BCState in checkpointed sweeps.

    Parameters
    ----------
    args : sequence
        Packed arguments, read by position:

        0. ``data`` -- data table; transposed before being given to
           ``BCState``.
        1. ``checkpoint`` -- iterations per sweep, or ``None`` to run all
           ``trans_kwargs['N']`` iterations in a single sweep.
        2. ``t_id`` -- identifier shown in verbose output.
        3. ``verbose`` -- if truthy, print progress after every sweep.
        4. ``init_kwargs`` -- keyword arguments for ``BCState``.
        5. ``trans_kwargs`` -- keyword arguments for ``state.transition``.
           Must contain ``'N'``, the total iteration count. The caller's
           dict is copied, never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)``. ``metadata`` is ``state.get_metadata()``
        after the last sweep. ``diagnostics`` holds one dict per sweep with
        keys ``'log_score'``, ``'n_views'``, ``'iters'`` (iterations run in
        that sweep) and ``'time'`` (seconds spent in ``transition``).

    Raises
    ------
    ValueError
        If ``checkpoint`` is given but is not positive.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]
    trans_kwargs = args[5]

    # create copy of trans_kwargs so we don't mutate
    trans_kwargs = dict(trans_kwargs)
    n_iter = trans_kwargs['N']
    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint <= 0:
            raise ValueError(
                "checkpoint must be positive, got {}".format(checkpoint))
        # Full checkpoint-sized sweeps plus one shorter sweep for the
        # remainder, so exactly N iterations run. Plain int(N / checkpoint)
        # dropped the remainder, and ran nothing when checkpoint > N.
        n_full, remainder = divmod(n_iter, checkpoint)
        sweep_sizes = [checkpoint] * int(n_full)
        if remainder:
            sweep_sizes.append(remainder)
    n_sweeps = len(sweep_sizes)

    msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
    msg += "\n\t+ log score: {}"
    msg += "\n\t+ n_views: {}\n"

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major
    for i, sweep_iters in enumerate(sweep_sizes):
        trans_kwargs['N'] = sweep_iters
        # perf_counter is monotonic, unlike the wall clock time.time().
        t_start = time.perf_counter()
        state.transition(**trans_kwargs)
        t_iter = time.perf_counter() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostic = {
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter
        }
        diagnostics.append(diagnostic)

        if verbose:
            # 1-based so the final sweep reads "sweep n of n"
            print(msg.format(t_id, i + 1, n_sweeps, t_iter, log_score,
                             n_views))

    metadata = state.get_metadata()

    return metadata, diagnostics
Example #9
0
def test_transition(gdat):
    """A default-constructed BCState runs ``transition()`` without raising."""
    X, _, _ = gdat
    BCState(X).transition()
Example #10
0
def test_predictive_logp(gdat):
    """``predictive_logp`` returns a value with and without its two optional
    trailing arguments.

    Each call's result is asserted. Previously the first result was
    overwritten before being checked.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    res = bcstate.predictive_logp([[2, 2]], [1.2], [[1, 3]], [0.0])
    assert res is not None

    res = bcstate.predictive_logp([[2, 2]], [1.2])
    assert res is not None
Example #11
0
def test_predictive_draw(gdat):
    """``predictive_draw`` returns a value with and without its two optional
    trailing arguments.

    Each call's result is asserted. Previously the first result was
    overwritten before being checked.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    res = bcstate.predictive_draw([[2, 2]], [[1, 3]], [4.1])
    assert res is not None

    res = bcstate.predictive_draw([[2, 2]])
    assert res is not None
Example #12
0
def test_get_metadata(gdat):
    """``get_metadata()`` on a default-constructed BCState returns a value."""
    X, _, _ = gdat
    metadata = BCState(X).get_metadata()
    assert metadata is not None
Example #13
0
def test_transition(gdat):
    """A default-constructed BCState runs ``transition()`` without raising."""
    X, _, _ = gdat
    BCState(X).transition()
Example #14
0
def test_initalizes_no_args(gdat):
    """BCState can be constructed from the data array alone."""
    data, _, _ = gdat
    assert BCState(data) is not None
Example #15
0
def test_initalizes_base(gdat):
    """BCState accepts data, datatypes list and distargs positionally."""
    data, dtypes, distargs = gdat
    state = BCState(data, dtypes, distargs)
    assert state is not None
Example #16
0
def test_get_metadata(gdat):
    """``get_metadata()`` on a default-constructed BCState returns a value."""
    X, _, _ = gdat
    metadata = BCState(X).get_metadata()
    assert metadata is not None
Example #17
0
def test_predictive_draw(gdat):
    """``predictive_draw`` returns a value with and without its two optional
    trailing arguments.

    Each call's result is asserted. Previously the first result was
    overwritten before being checked.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    res = bcstate.predictive_draw([[2, 2]], [[1, 3]], [4.1])
    assert res is not None

    res = bcstate.predictive_draw([[2, 2]])
    assert res is not None
Example #18
0
def test_predictive_probability(gdat):
    """``predictive_probability`` returns a value with and without its two
    optional trailing arguments.

    Each call's result is asserted. Previously the first result was
    overwritten before being checked.
    """
    X, _, _ = gdat
    bcstate = BCState(X)

    res = bcstate.predictive_probability([[2, 2]], [1.2], [[1, 3]], [0.0])
    assert res is not None

    res = bcstate.predictive_probability([[2, 2]], [1.2])
    assert res is not None
Example #19
0
def test_conditioned_row_resample(gdat):
    """``conditioned_row_resample`` runs with a constant-zero callback."""
    X, _, _ = gdat
    state = BCState(X)

    def constant_zero(_):
        return 0.0

    state.conditioned_row_resample(2, constant_zero)