def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
    """
    Visualize a cross-categorization state.

    .. note :: Work in progress

        This function is currently only suited to small data tables. There
        are problems with the row labels overlapping. There are problems
        with singleton views and categories having negative size or
        appearing as lines. Lots to fix.

    Parameters
    ----------
    state_idx
        Index into ``self._models`` selecting the state to draw.
    hl_rows : label, list or numpy.ndarray of labels, optional
        Row label(s) to highlight. Each label is translated through
        ``self._converters['row2idx']``. The default empty tuple means
        "highlight nothing". Any other non-list, non-array value is
        treated as one single label.
    hl_cols : label, list or numpy.ndarray of labels, optional
        Same as `hl_rows`, translated through
        ``self._converters['col2idx']``.

    Raises
    ------
    NotImplementedError
        If ``self._subsampled`` is truthy.
    """
    if self._subsampled:
        raise NotImplementedError("Not implemented for sub-sampled states")

    def _to_indices(labels, converter_key):
        # Only the empty-tuple default means "nothing to highlight". A
        # plain ``labels != ()`` test is an elementwise comparison for
        # ndarray input and does not yield a usable bool (it raises, or
        # produces an empty/ambiguous array), so test the type explicitly.
        if isinstance(labels, tuple) and len(labels) == 0:
            return labels
        if not isinstance(labels, (list, np.ndarray)):
            # a single label was passed; wrap it
            labels = [labels]
        lookup = self._converters[converter_key]
        return [lookup[label] for label in labels]

    hl_rows = _to_indices(hl_rows, 'row2idx')
    hl_cols = _to_indices(hl_cols, 'col2idx')

    model = self._models[state_idx]
    init_kwargs = {
        'dtypes': self._dtypes,
        'distargs': self._distargs,
        'Zv': model['col_assignment'],
        'Zrcv': model['row_assignments'],
        'col_hypers': model['col_hypers'],
        'state_alpha': model['state_alpha'],
        'view_alphas': model['view_alphas'],
    }

    # Rebuild the state from the stored model (BCState receives the
    # transposed data table) to obtain the log probabilities for the plot.
    state = BCState(self._data.T, **init_kwargs)
    model_logps = state.get_logps()

    pu.plot_cc_model(self._data, model, model_logps, self._df.index,
                     self._df.columns, hl_rows=hl_rows, hl_cols=hl_cols)
def _initialize(args):
    """
    Construct a state and return its metadata plus an initial diagnostic.

    Parameters
    ----------
    args : sequence
        ``args[0]`` is the data table (its ``.T`` is handed to ``BCState``)
        and ``args[1]`` is a dict of keyword arguments for ``BCState``.

    Returns
    -------
    tuple
        ``(metadata, [diagnostics])``. ``metadata`` is the result of
        ``state.get_metadata()``. ``diagnostics`` is a dict with the keys
        ``'log_score'``, ``'iters'`` (always 0) and ``'time'``, the seconds
        elapsed from just before construction until the log score has been
        computed.
    """
    data, init_kwargs = args[0], args[1]

    started = time.time()
    # transpose data to col-major
    state = BCState(data.T, **init_kwargs)
    metadata = state.get_metadata()

    log_score = state.log_score()
    elapsed = time.time() - started
    diagnostics = dict(log_score=log_score, iters=0, time=elapsed)

    return metadata, [diagnostics]
def _initialize(args):
    """
    Build a ``BCState`` from packed arguments and summarize it.

    Parameters
    ----------
    args : sequence
        Positional bundle: ``args[0]`` is the data table (transposed
        before being passed on) and ``args[1]`` holds the keyword
        arguments forwarded to ``BCState``.

    Returns
    -------
    tuple
        The state's metadata and a one-element list containing a dict with
        ``'log_score'``, ``'iters'`` (fixed at 0) and ``'time'`` (seconds
        between the start of construction and the end of scoring).
    """
    table = args[0]
    state_kwargs = args[1]

    t_begin = time.time()
    state = BCState(table.T, **state_kwargs)  # col-major input
    metadata = state.get_metadata()

    score = state.log_score()
    first_entry = {'log_score': score, 'iters': 0}
    first_entry['time'] = time.time() - t_begin

    return metadata, [first_entry]
def blank():
    """
    Return a ``BCState`` built from random data with default arguments.

    The input is a ``np.random.rand`` array of shape (4, 10): four columns
    by ten rows, laid out column major.
    """
    # column major: first axis is columns, second axis is rows
    shape = (4, 10)
    return BCState(np.random.rand(*shape))
def plot_state(self, state_idx, hl_rows=(), hl_cols=()):
    """
    Visualize a cross-categorization state.

    .. note :: Work in progress

        This function is currently only suited to small data tables. There
        are problems with the row labels overlapping. There are problems
        with singleton views and categories having negative size or
        appearing as lines. Lots to fix.

    Parameters
    ----------
    state_idx
        Index into ``self._models`` selecting the state to draw.
    hl_rows : label, list or numpy.ndarray of labels, optional
        Row label(s) to highlight. Each label is translated through
        ``self._converters['row2idx']``. The default empty tuple means
        "highlight nothing". Any other non-list, non-array value is
        treated as one single label.
    hl_cols : label, list or numpy.ndarray of labels, optional
        Same as `hl_rows`, translated through
        ``self._converters['col2idx']``.

    Raises
    ------
    NotImplementedError
        If ``self._subsampled`` is truthy.
    """
    if self._subsampled:
        raise NotImplementedError("Not implemented for sub-sampled states")

    def _to_indices(labels, converter_key):
        # Only the empty-tuple default means "nothing to highlight". A
        # plain ``labels != ()`` test is an elementwise comparison for
        # ndarray input and does not yield a usable bool (it raises, or
        # produces an empty/ambiguous array), so test the type explicitly.
        if isinstance(labels, tuple) and len(labels) == 0:
            return labels
        if not isinstance(labels, (list, np.ndarray)):
            # a single label was passed; wrap it
            labels = [labels]
        lookup = self._converters[converter_key]
        return [lookup[label] for label in labels]

    hl_rows = _to_indices(hl_rows, 'row2idx')
    hl_cols = _to_indices(hl_cols, 'col2idx')

    model = self._models[state_idx]
    init_kwargs = {
        'dtypes': self._dtypes,
        'distargs': self._distargs,
        'Zv': model['col_assignment'],
        'Zrcv': model['row_assignments'],
        'col_hypers': model['col_hypers'],
        'state_alpha': model['state_alpha'],
        'view_alphas': model['view_alphas'],
    }

    # Rebuild the state from the stored model (BCState receives the
    # transposed data table) to obtain the log probabilities for the plot.
    state = BCState(self._data.T, **init_kwargs)
    model_logps = state.get_logps()

    pu.plot_cc_model(self._data, model, model_logps, self._df.index,
                     self._df.columns, hl_rows=hl_rows, hl_cols=hl_cols)
def _run(args):
    """
    Run transitions on one model and collect per-sweep diagnostics.

    Parameters
    ----------
    args : sequence
        Positional bundle (presumably packed so the function can be mapped
        over workers -- confirm against the caller):

        - ``args[0]``: data table; its ``.T`` is passed to ``BCState``.
        - ``args[1]``: checkpoint, the number of iterations per sweep, or
          ``None`` to run all iterations in a single sweep.
        - ``args[2]``: model id, used only in the verbose message.
        - ``args[3]``: verbose flag; when truthy a summary is printed
          after every sweep.
        - ``args[4]``: dict of keyword arguments for ``BCState``.
        - ``args[5]``: dict of keyword arguments for ``state.transition``;
          must contain ``'N'``, the total number of iterations. The dict
          is copied, never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)`` where ``metadata`` comes from
        ``state.get_metadata()`` and ``diagnostics`` is a list with one
        dict per sweep (keys ``'log_score'``, ``'n_views'``, ``'iters'``,
        ``'time'``).

    Raises
    ------
    ValueError
        If a checkpoint is given and is smaller than 1.

    Notes
    -----
    When the checkpoint does not evenly divide ``N``, a final shorter
    sweep runs the leftover iterations so that exactly ``N`` iterations
    are performed. Its ``'iters'`` entry records the shorter length.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]

    # create copy of trans_kwargs so we don't mutate the caller's dict
    trans_kwargs = dict(args[5])
    n_iter = trans_kwargs['N']

    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint < 1:
            raise ValueError("checkpoint must be a positive number of "
                             "iterations")
        n_full = int(n_iter // checkpoint)
        leftover = n_iter - n_full * checkpoint
        sweep_sizes = [checkpoint] * n_full
        if leftover > 0:
            # Truncating here would silently drop these iterations (and
            # run nothing at all when checkpoint > N).
            sweep_sizes.append(leftover)
    n_sweeps = len(sweep_sizes)

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major

    for i, sweep_iters in enumerate(sweep_sizes):
        trans_kwargs['N'] = sweep_iters

        t_start = time.time()
        state.transition(**trans_kwargs)
        t_iter = time.time() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostics.append({
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter,
        })

        if verbose:
            msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
            msg += "\n\t+ log score: {}"
            msg += "\n\t+ n_views: {}\n"
            print(msg.format(t_id, i, n_sweeps, t_iter, log_score, n_views))

    metadata = state.get_metadata()
    return metadata, diagnostics
def _init_bcstates(self): self._bcstates = [] for model in self._models: sd = np.random.randint(2**31 - 1) init_kwarg = { 'dtypes': self._dtypes, 'distargs': self._distargs, 'Zv': model['col_assignment'], 'Zrcv': model['row_assignments'], 'col_hypers': model['col_hypers'], 'state_alpha': model['state_alpha'], 'view_alphas': model['view_alphas'], 'seed': sd } self._bcstates.append(BCState(self._data.T, **init_kwarg))
def _run(args):
    """
    Run transitions on one model and collect per-sweep diagnostics.

    Parameters
    ----------
    args : sequence
        Positional bundle (presumably packed so the function can be mapped
        over workers -- confirm against the caller):

        - ``args[0]``: data table; its ``.T`` is passed to ``BCState``.
        - ``args[1]``: checkpoint, the number of iterations per sweep, or
          ``None`` to run all iterations in a single sweep.
        - ``args[2]``: model id, used only in the verbose message.
        - ``args[3]``: verbose flag; when truthy a summary is printed
          after every sweep.
        - ``args[4]``: dict of keyword arguments for ``BCState``.
        - ``args[5]``: dict of keyword arguments for ``state.transition``;
          must contain ``'N'``, the total number of iterations. The dict
          is copied, never mutated.

    Returns
    -------
    tuple
        ``(metadata, diagnostics)`` where ``metadata`` comes from
        ``state.get_metadata()`` and ``diagnostics`` is a list with one
        dict per sweep (keys ``'log_score'``, ``'n_views'``, ``'iters'``,
        ``'time'``).

    Raises
    ------
    ValueError
        If a checkpoint is given and is smaller than 1.

    Notes
    -----
    When the checkpoint does not evenly divide ``N``, a final shorter
    sweep runs the leftover iterations so that exactly ``N`` iterations
    are performed. Its ``'iters'`` entry records the shorter length.
    """
    data = args[0]
    checkpoint = args[1]
    t_id = args[2]
    verbose = args[3]
    init_kwargs = args[4]

    # create copy of trans_kwargs so we don't mutate the caller's dict
    trans_kwargs = dict(args[5])
    n_iter = trans_kwargs['N']

    if checkpoint is None:
        sweep_sizes = [n_iter]
    else:
        if checkpoint < 1:
            raise ValueError("checkpoint must be a positive number of "
                             "iterations")
        n_full = int(n_iter // checkpoint)
        leftover = n_iter - n_full * checkpoint
        sweep_sizes = [checkpoint] * n_full
        if leftover > 0:
            # Truncating here would silently drop these iterations (and
            # run nothing at all when checkpoint > N).
            sweep_sizes.append(leftover)
    n_sweeps = len(sweep_sizes)

    diagnostics = []
    state = BCState(data.T, **init_kwargs)  # transpose data to col-major

    for i, sweep_iters in enumerate(sweep_sizes):
        trans_kwargs['N'] = sweep_iters

        t_start = time.time()
        state.transition(**trans_kwargs)
        t_iter = time.time() - t_start

        n_views = state.n_views
        log_score = state.log_score()

        diagnostics.append({
            'log_score': log_score,
            'n_views': n_views,
            'iters': sweep_iters,
            'time': t_iter,
        })

        if verbose:
            msg = "Model {}:\n\t+ sweep {} of {} in {} sec."
            msg += "\n\t+ log score: {}"
            msg += "\n\t+ n_views: {}\n"
            print(msg.format(t_id, i, n_sweeps, t_iter, log_score, n_views))

    metadata = state.get_metadata()
    return metadata, diagnostics
def test_transition(gdat):
    """Smoke test: a default-constructed state runs ``transition()``."""
    data, _, _ = gdat
    state = BCState(data)

    # Passes as long as no exception is raised.
    state.transition()
def test_predictive_logp(gdat):
    """
    ``predictive_logp`` accepts both its four- and two-argument forms.

    Only the result of the two-argument call is asserted; the
    four-argument call merely has to complete without raising.
    """
    data, _, _ = gdat
    state = BCState(data)

    # four-argument form: result intentionally not inspected
    state.predictive_logp([[2, 2]], [1.2], [[1, 3]], [0.0])

    # two-argument form: must return something
    logp = state.predictive_logp([[2, 2]], [1.2])
    assert logp is not None
def test_predictive_draw(gdat):
    """
    ``predictive_draw`` accepts both its three- and one-argument forms.

    Only the result of the one-argument call is asserted; the
    three-argument call merely has to complete without raising.
    """
    data, _, _ = gdat
    state = BCState(data)

    # three-argument form: result intentionally not inspected
    state.predictive_draw([[2, 2]], [[1, 3]], [4.1])

    # one-argument form: must return something
    draw = state.predictive_draw([[2, 2]])
    assert draw is not None
def test_get_metadata(gdat):
    """A default-constructed state returns non-None metadata."""
    data, _, _ = gdat
    metadata = BCState(data).get_metadata()
    assert metadata is not None
def test_initalizes_no_args(gdat):
    """``BCState`` can be constructed from the data alone."""
    data, _, _ = gdat
    state = BCState(data)
    assert state is not None
def test_initalizes_base(gdat):
    """
    ``BCState`` can be constructed with all three fixture elements passed
    positionally: the data, the datatype list and the distargs.
    """
    data, dtypes, dist_args = gdat
    state = BCState(data, dtypes, dist_args)
    assert state is not None
def test_predictive_probability(gdat):
    """
    ``predictive_probability`` accepts its four- and two-argument forms.

    Only the result of the two-argument call is asserted; the
    four-argument call merely has to complete without raising.
    """
    data, _, _ = gdat
    state = BCState(data)

    # four-argument form: result intentionally not inspected
    state.predictive_probability([[2, 2]], [1.2], [[1, 3]], [0.0])

    # two-argument form: must return something
    prob = state.predictive_probability([[2, 2]], [1.2])
    assert prob is not None
def test_conditioned_row_resample(gdat):
    """
    Smoke test: ``conditioned_row_resample`` runs with index 2 and a
    callable that always returns 0.0.
    """
    data, _, _ = gdat
    state = BCState(data)

    def constant_zero(x):
        return 0.0

    # Passes as long as no exception is raised.
    state.conditioned_row_resample(2, constant_zero)