def random_regression_datasets(n_samples, features=100, outs=1, informative=.1,
                               partition_proportions=(.5, .3), rnd=None, **mk_rgr_kwargs):
    """Generate a random regression problem and wrap it as partitioned `em.Datasets`.

    :param n_samples: total number of examples to generate
    :param features: number of input features
    :param outs: number of regression targets
    :param informative: fraction of `features` that are informative
    :param partition_proportions: proportions used by `redivide_data` for the splits
    :param rnd: seed / random state (resolved through `em.get_rand_state`)
    :param mk_rgr_kwargs: extra keyword arguments forwarded to `make_regression`
    :return: `em.Datasets` with train/validation(/test) splits
    """
    state = em.get_rand_state(rnd)
    data, targets, coefs = make_regression(n_samples, features, int(features * informative),
                                           outs, random_state=state, coef=True,
                                           **mk_rgr_kwargs)
    # make_regression returns a 1-d target for a single output; keep it 2-d
    if outs == 1:
        targets = np.reshape(targets, (n_samples, 1))
    print('range of Y', np.min(targets), np.max(targets))
    # record generation settings (and the true coefficients) alongside the data
    meta = utils.merge_dicts({'informative': informative, 'random_seed': rnd, 'w': coefs},
                             mk_rgr_kwargs)
    ds_name = em.utils.name_from_dict(meta, 'w')  # 'w' excluded from the name
    whole = em.Dataset(data, targets, name=ds_name, info=meta)
    splits = em.Datasets.from_list(redivide_data([whole], partition_proportions))
    print('conditioning of X^T X',
          np.linalg.cond(splits.train.data.T @ splits.train.data))
    return splits
def random_classification_datasets(n_samples, features=100, classes=2, informative=.1,
                                   partition_proportions=(.5, .3), rnd=None, one_hot=True,
                                   **mk_cls_kwargs):
    """Generate a random classification problem and wrap it as partitioned `em.Datasets`.

    :param n_samples: total number of examples to generate
    :param features: number of input features
    :param classes: number of classes
    :param informative: fraction of `features` that are informative
    :param partition_proportions: proportions used by `redivide_data` for the splits
    :param rnd: seed / random state (resolved through `em.get_rand_state`)
    :param one_hot: if True, one-hot encode the targets
    :param mk_cls_kwargs: extra keyword arguments forwarded to `make_classification`
    :return: `em.Datasets` with train/validation(/test) splits
    """
    rnd_state = em.get_rand_state(rnd)
    # FIX: `informative` was accepted and stored in `info` but never forwarded to
    # make_classification. Mirror the regression helper by translating the fraction
    # into `n_informative`; `setdefault` keeps any explicit caller override intact.
    mk_cls_kwargs.setdefault('n_informative', int(features * informative))
    X, Y = make_classification(n_samples, features, n_classes=classes,
                               random_state=rnd_state, **mk_cls_kwargs)
    if one_hot:
        Y = utils.to_one_hot_enc(Y)
    print('range of Y', np.min(Y), np.max(Y))
    # record generation settings alongside the data
    info = utils.merge_dicts({'informative': informative, 'random_seed': rnd},
                             mk_cls_kwargs)
    name = em.utils.name_from_dict(info, 'w')
    dt = em.Dataset(X, Y, name=name, info=info)
    datasets = em.Datasets.from_list(redivide_data([dt], partition_proportions))
    print('conditioning of X^T X',
          np.linalg.cond(datasets.train.data.T @ datasets.train.data))
    return datasets
def _training_supplier(step=None):
    """Return the feed dict for training step `step`: a mini-batch of
    (data, target) drawn from the precomputed visiting scheme, merged with
    any extra feeds in `other_feeds` (callables receive `step`)."""
    nonlocal other_feeds

    # Wrap around when the step counter passes the end of the scheme.
    if step >= self.T:
        if step % self.T == 0:
            if self.epochs:
                print(
                    'WARNING: End of the training scheme reached.'
                    'Generating another scheme.',
                    file=sys.stderr)
            self.generate_visiting_scheme()
        step %= self.T

    # Lazily build the visiting scheme on first use.
    if self.training_schedule is None:
        # print('visiting scheme not yet generated!')
        self.generate_visiting_scheme()

    # Indices of this step's mini-batch; min() guards the final, possibly
    # shorter, batch.
    # noinspection PyTypeChecker
    nb = self.training_schedule[step * self.batch_size:min(
        (step + 1) * self.batch_size, len(self.training_schedule))]

    bx = self.dataset.data[nb, :]
    by = self.dataset.target[nb, :]

    # NOTE(review): a previous implementation also processed `lambda_feeds`
    # here ({k: v(nb) ...}); purpose unclear, left removed.
    return merge_dicts({
        x: bx,
        y: by
    }, *[maybe_call(of, step) for of in other_feeds])
def stack(*datasets):
    """Stack several structurally-identical datasets into a single `Dataset`.

    Data, targets and sample_info are concatenated in order; each info key
    maps to the list of per-dataset values (None where a dataset lacks it).

    :param datasets: datasets sharing the same structure
    :return: the stacked `Dataset`
    """
    stacked_data = vstack([d.data for d in datasets])
    stacked_target = stack_or_concat([d.target for d in datasets])
    stacked_sample_info = np.concatenate([d.sample_info for d in datasets])
    # union of all info keys, each mapped to its value in every dataset
    all_info_keys = merge_dicts(*[d.info for d in datasets])
    stacked_info = {
        key: [d.info.get(key, None) for d in datasets]
        for key in all_info_keys
    }
    return Dataset(data=stacked_data,
                   target=stacked_target,
                   sample_info=stacked_sample_info,
                   info=stacked_info)
def autoplot(saver_or_history, saver=None, append_string='', clear_output=True, show_plots=True):
    """Plot every recorded history series, one figure per top-level key.

    History keys of the form 'group::label' are nested by group; groups named
    'SKIP' or 'HIDE' are not plotted. Figures are optionally saved through
    `saver` and shown with matplotlib.

    :param saver_or_history: a `Saver` (history is loaded/packed from it), a
        history dict, or a list/tuple of either (plotted recursively)
    :param saver: `Saver` used to save figures (overridden when
        `saver_or_history` is itself a `Saver`)
    :param append_string: suffix for packed-object lookup and figure names
    :param clear_output: best-effort clearing of notebook output first
    :param show_plots: call `plt.show()` for each figure
    :return: 'done...' on success, 'nothing yet...' if no history was found
    """
    if clear_output:
        try:
            c_out()
        except Exception:  # best-effort: clearing only works inside notebooks
            pass
    if isinstance(saver_or_history, (list, tuple)):
        return [
            autoplot(soh, saver, append_string, clear_output=False)
            for soh in saver_or_history
        ]
    if isinstance(saver_or_history, Saver):
        saver = saver_or_history
        history = saver.pack_save_dictionaries(erase_others=False,
                                               save_packed=False,
                                               append_string=append_string)
        if history is None:
            try:
                history = saver.load_obj('all__%s' % append_string)
            except FileNotFoundError:
                print('Packed object not found', file=sys.stderr)
        if history is None:
            return 'nothing yet...'
    else:
        history = saver_or_history

    def _simple_plot(_title, _label, _v):
        # Plot one series; non-list values are silently skipped.
        try:
            # FIX: test the parameter _v, not the enclosing loop variable v
            # (the original only worked because v happened to alias _v at the
            # call site).
            if isinstance(_v, list):
                plt.title(_title.capitalize())
                plt.plot(_v, label=_label.capitalize())
        except Exception as e:
            # FIX (was a TODO): report the cause instead of swallowing it
            print('Could not plot %s' % _title, e, file=sys.stderr)

    # group 'group::label' keys into {group: {label: series}}
    nest = defaultdict(lambda: {})
    for k, v in history.items():
        k_split = k.split('::')
        k_1 = k_split[1] if len(k_split) > 1 else ''
        nest[k_split[0]] = merge_dicts(nest[k_split[0]], {k_1: v})

    for k, _dict_k in nest.items():
        if k != 'SKIP' and k != 'HIDE':
            plt.figure(figsize=(8, 6))
            for kk, v in _dict_k.items():
                _simple_plot(k, kk, v)
            # only add a legend when every label is non-empty
            if all([kk for kk in _dict_k.keys()]):
                plt.legend(loc=0)
            if saver and saver.collect_data:
                saver.save_fig(append_string + '_' + k)
                saver.save_fig(append_string + '_' + k, extension='png')
            if show_plots:
                plt.show()
            plt.close()
            print('=' * 50)
    return 'done...'