示例#1
0
    def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True):
        """Count ROI targets per patient and print / plot dataset statistics.

        On first use, ``self.df`` is built from ``self.data``: one row per pid, a ``pid`` column,
        and one column per target label name holding how many ROIs of that target the pid has
        (0 if none). The table is cached on the instance and reused by later calls.

        :param subsets: optional mapping ``{subset_name: iterable_of_pids}``. When given, counts
            are summed per subset and listed in the mapping's iteration order; ``self.df`` gains
            ``subset`` and ``display_order`` columns. A pid listed in several subsets is
            attributed to the last one.
        :param plot_dir: optional directory (created if missing) into which the fold and/or
            overall statistics are plotted as PDFs.
        :param overall_stats: if True, print (and possibly plot) the counts over all pids.
        :return: ``(df, labels)``. ``df`` is the per-subset count table if ``subsets`` was given,
            else the per-pid count table, or ``None`` if neither was computed. ``labels`` holds
            the label objects of the configured balance target (``None`` for targets without a
            label table).
        """
        balance_t = getattr(self.cf, "balance_target", "class_targets")
        # Resolve the label table on every call (not only when self.df is built), so that
        # ``labels`` is always defined, including on repeated calls with a cached self.df.
        table_attr = {"class_targets": "class_id2label",
                      "rg_bin_targets": "bin_id2label",
                      "lesion_gleasons": "gs2label"}.get(balance_t)
        if table_attr is None:
            # Unknown targets (e.g. "regression_targets"): the raw value is the column name.
            labels = None

            def to_name(target):
                return target
        else:
            id2label = getattr(self.cf, table_attr)
            labels = id2label.values()

            def to_name(target):
                return id2label[target].name

        if self.df is None:
            frames = []
            for pid, subj_data in self.data.items():
                targets, counts = np.unique(subj_data[balance_t], return_counts=True)
                row = {}
                for target, count in zip(targets, counts):
                    name = to_name(target)
                    # Several target values may share one label name; add their counts rather
                    # than letting the last one overwrite the others.
                    row[name] = row.get(name, 0) + count
                frames.append(pd.DataFrame({"pid": [pid], **{name: [n] for name, n in row.items()}}))
            # A single concat instead of repeated DataFrame.append: linear rather than quadratic,
            # and DataFrame.append no longer exists in pandas >= 2.0. sort=True reproduces the
            # sorted column order that append(..., sort=True) produced.
            if frames:
                counts_df = pd.concat(frames, ignore_index=True, sort=True)
            else:
                counts_df = pd.DataFrame(columns=["pid"])
            self.df = counts_df.fillna(0)
            # Tag only this frame with the balance target. Appending to ``_metadata`` in place
            # would mutate the class-level list pandas shares across every DataFrame.
            self.df._metadata = [*self.df._metadata, balance_t]

        overall_df = fold_df = None
        if overall_stats:
            # Drop the id column and any helper columns a previous call with ``subsets`` added,
            # so only per-label counts remain.
            overall_df = self.df.drop(columns=["pid", "subset", "display_order"], errors="ignore")
            overall_df = overall_df.reindex(sorted(overall_df.columns), axis=1).astype('uint32')
            print("Overall dataset roi counts per target kind:")
            print(overall_df.sum())

        if subsets is not None:
            subset_of, order_of = {}, {}
            for ix, (subset, pids) in enumerate(subsets.items()):
                for pid in pids:
                    # Later subsets overwrite earlier ones for a pid listed more than once.
                    subset_of[pid] = subset
                    order_of[pid] = ix
            # Unlisted pids map to NaN and are therefore excluded by groupby below.
            self.df["subset"] = self.df["pid"].map(subset_of)
            self.df["display_order"] = self.df["pid"].map(order_of)
            grouped = self.df.drop(columns=["pid"]).groupby("subset")
            # display_order is constant within a subset, so take it once per group. Summing it
            # like the counts would scale it by subset size and scramble the display order.
            order = grouped["display_order"].min().sort_values(kind="stable")
            fold_df = grouped.sum().drop(columns="display_order").loc[order.index].astype('int64')
            fold_df = fold_df.reindex(sorted(fold_df.columns), axis=1)
            print("Fold {} dataset roi counts per target kind:".format(self.cf.fold))
            print(fold_df)

        if plot_dir is not None:
            os.makedirs(plot_dir, exist_ok=True)
            if subsets is not None:
                plg.plot_fold_stats(self.cf, fold_df, labels,
                                    os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold)) + ".pdf")
            if overall_stats:
                # Always plot the overall (per-pid) table here, even when subsets were given.
                plg.plot_data_stats(self.cf, overall_df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf'))

        return (fold_df if subsets is not None else overall_df), labels
data_dict = utils.read_file(data_path)

# Switch to True to write diagnostic figures of the raw and spikified data into figure_dir.
do_plot = False
if do_plot:
    plotting.plot_data_pca(data_dict)
    plt.savefig(os.path.join(figure_dir, 'data_pca.png'))
    plotting.plot_data_example(*(data_dict[key] for key in ('inputs', 'hiddens', 'outputs', 'targets')))
    plt.savefig(os.path.join(figure_dir, 'data_example.png'))

# The hidden states were produced by a VRNN with tanh units, so (hiddens + 1) / 2 rescales
# them into [0, 1] before they are converted into spikes.
data_bxtxn = utils.spikify_data(
    (data_dict['hiddens'] + 1) / 2,
    onp_rng,
    data_dt,
    max_firing_rate=max_firing_rate,
)
if do_plot:
    plotting.plot_data_stats(data_dict, data_bxtxn, data_dt)
    plt.savefig(os.path.join(figure_dir, 'data_stats.png'))
train_data, eval_data = utils.split_data(data_bxtxn, train_fraction=0.9)

### LFADS Hyper parameters
# LFADS inputs are laid out as (batch_size x ntimesteps x data_dim).
ntimesteps = train_data.shape[1]
data_dim = train_data.shape[2]
batch_size = 128  # minibatch size used during optimization

# LFADS architecture
enc_dim = 64      # encoder dimension
con_dim = 64      # controller dimension
ii_dim = 1        # inferred-input dimension
gen_dim = 75      # generator dimension
factors_dim = 20  # factors dimension