def __init__(
    self,
    data_path,  # for compatibility with other datasets API
    train=True,
    transform=None,
    download=True,
    cluster_params=None,
    n_samples=300,
):
    """Init Blobs dataset."""
    super(CausalBlobs, self).__init__()
    self.root = data_path
    self.transform = transform if transform is not None else {}
    self.train = train  # training set or test set
    self.n_samples = n_samples

    if cluster_params is None:
        self.cluster_params = dict(
            n_clusters=2, data_seed=0, radius=0.02, centers=None, proba_classes=0.5
        )
    else:
        self.cluster_params = cluster_params

    # Hash the (JSON-serializable) cluster and transform parameters to build a
    # unique data directory. Use self.cluster_params so the defaults above are
    # also covered when cluster_params is None.
    tmp_cluster_params = self.cluster_params.copy()
    if isinstance(tmp_cluster_params.get("centers"), np.ndarray):
        tmp_cluster_params["centers"] = tmp_cluster_params["centers"].tolist()
    cluster_hash = xp.param_to_hash(tmp_cluster_params)
    transform_hash = xp.param_to_hash(transform)
    self.data_dir = os.path.join(cluster_hash, transform_hash)
    root_dir = os.path.join(self.root, self.raw_folder)
    os.makedirs(root_dir, exist_ok=True)
    xp.record_hashes(
        os.path.join(root_dir, "parameters.json"),
        f"{cluster_hash}/{transform_hash}",
        {"cluster_params": tmp_cluster_params, "transform": transform},
    )
    self.training_file = "causal_blobs_train.pt"
    self.test_file = "causal_blobs_test.pt"
    self._cluster_gen = None

    # Generate the data on disk if it is missing or a refresh was requested.
    if not self._check_exists() or download:
        self.create_on_disk()

    if not self._check_exists():
        raise RuntimeError("Dataset not found.")

    if self.train:
        self.data, self.targets = torch.load(
            os.path.join(
                self.root, self.raw_folder, self.data_dir, self.training_file
            )
        )
    else:
        self.data, self.targets = torch.load(
            os.path.join(self.root, self.raw_folder, self.data_dir, self.test_file)
        )
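# Usage sketch (illustrative only): constructing the dataset with explicit
# cluster parameters. The values below mirror the defaults in __init__; the
# centers array and the "data/" path are assumptions for illustration.
example_cluster_params = dict(
    n_clusters=2,
    data_seed=0,
    radius=0.02,
    centers=np.array([[-0.5, 0.0], [0.5, 0.0]]),
    proba_classes=0.5,
)
train_set = CausalBlobs(
    "data/",
    train=True,
    transform={},
    download=True,
    cluster_params=example_cluster_params,
    n_samples=300,
)
print(train_set.data.shape, train_set.targets.shape)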
test_params = network_params.copy()
data_factory = DatasetFactory(data_params, data_path=args.data_path)
output_dir = os.path.join(args.outdir, data_factory.get_data_short_name())
os.makedirs(output_dir, exist_ok=True)

# parameters that change across experiments for the same dataset
record_params = test_params.copy()
record_params.update(
    {
        k: v
        for k, v in data_params.items()
        if k not in ("dataset_group", "dataset_name")
    }
)
params_hash = xp.param_to_hash(record_params)
hash_file = os.path.join(output_dir, "parameters.json")
xp.record_hashes(hash_file, params_hash, record_params)
output_file_prefix = os.path.join(output_dir, params_hash)
test_csv_file = f"{output_file_prefix}.csv"
checkpoint_dir = os.path.join(output_dir, "checkpoints", params_hash)

results = xpr.XpResults.from_file(
    ["source acc", "target acc", "domain acc"], test_csv_file
)
do_plots = False
methods_variant_params = xp.load_json_dict(args.method)
archis_res = {}
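# Resulting on-disk layout (sketch derived from the paths built above; the
# directory names come from args.outdir and DatasetFactory.get_data_short_name()):
#
#   <args.outdir>/<data_short_name>/
#       parameters.json            # params_hash -> record_params (xp.record_hashes)
#       <params_hash>.csv          # accuracy results read/written via XpResults
#       checkpoints/<params_hash>/ # model checkpoints for this parameter set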
def run_loop(force_run=False):
    if auto_run or force_run:
        run_state.markdown(
            f":timer_clock: Running with settings: few={fewshot}, nseeds={nseeds}, methods={methods}."
        )
    else:
        run_state.markdown(
            ":warning: Click button on the left to run training & evaluation."
        )
        return

    data_factory = DatasetFactory(toy_params, data_path=args.data_path, n_fewshot=fewshot)
    os.makedirs(output_dir, exist_ok=True)
    test_hash = xp.param_to_hash(test_params)
    output_file_prefix = os.path.join(output_dir, test_hash)
    checkpoint_dir = os.path.join(output_dir, "checkpoints", test_hash)
    xp.record_hashes(os.path.join(args.outdir, "data_hashes.json"), data_hash, toy_params)
    xp.record_hashes(os.path.join(args.outdir, "test_hashes.json"), test_hash, test_params)
    test_csv_file = f"{output_file_prefix}.csv"
    results = xpr.XpResults.from_file(
        ["source acc", "target acc", "domain acc"], test_csv_file
    )

    archis_res = {}
    for m in methods:
        st.write(f"Learning {nseeds} x {m}")
        pgbar = st.progress(0)
        domain_archi = xp.loop_train_test_model(
            m,
            results,
            nseeds,
            test_csv_file,
            test_params=test_params,
            data_factory=data_factory,
            force_run=False,
            gpus=gpus,
            checkpoint_dir=checkpoint_dir,
            progress_callback=lambda percent: pgbar.progress(percent),
        )
        if domain_archi is not None:
            archis_res[m] = domain_archi
    results.to_csv(test_csv_file)

    st.header("Results summary")
    st.write(f"Read from {test_csv_file}")
    st.dataframe(results.get_data().groupby(["method", "split"]).mean())
    print(results.get_data())
    fig = plt.figure()
    ax = sns.catplot(
        x="method", y="target acc", data=results.get_data(), kind="swarm", hue="split"
    )
    plt.ylabel("Accuracy")
    st.pyplot()
    results.append_to_txt(
        filepath=os.path.join(output_dir, "all_res.txt"),
        test_params=test_params,
        nseeds=nseeds,
    )

    st.header("Plots")
    logging.info("Recomputing context for best seed")
    mean_seed = results.get_mean_seed("target acc")
    archis_res = get_archi_or_file_for_seed(
        seed=mean_seed,
        methods=methods,
        file_prefix=output_file_prefix,
        test_params=test_params,
        data_factory=data_factory,
        checkpoint_dir=checkpoint_dir,
        gpus=gpus,
        fig_names=fig_names,
    )
    output_file_prefix = f"{output_file_prefix}_{mean_seed}"

    from PIL import Image

    logging.getLogger("matplotlib").setLevel(logging.ERROR)
    all_figs = defaultdict(dict)
    for (name, res) in archis_res.items():
        if isinstance(res, list):
            for fig_name in fig_names:
                for fig_file in res:
                    if "_".join(fig_name.split()) in fig_file:
                        all_figs[fig_name][name] = Image.open(fig_file)
        elif res is not None:
            figs = plot_archi_data(
                res,
                name,
                save_prefix=output_file_prefix,
                plot_f_lines="neurons" in fig_names,
                do_domain_boundary="domain boundary" in fig_names,
                plot_features=set(fig_names) & set(["PCA", "UMAP", "TSNE"]),
                do_entropy="entropy" in fig_names,
            )
            for fig_name, fig in figs.items():
                if fig_name in fig_names:
                    all_figs[fig_name][name] = fig
            plt.close("all")
        else:
            st.write(f"Cannot show result for {name}: {res}")

    for fig_name, figs in all_figs.items():
        st.subheader(fig_name)
        for meth_name, fig in figs.items():
            st.write(meth_name)
            if isinstance(fig, Image.Image):
                st.image(fig, use_column_width=True)
            else:
                st.pyplot(fig)

    logging.info(f"See results with prefix {output_file_prefix}")
    run_state.text(f"Done, see results in {output_file_prefix}.")
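# Sketch of how run_loop is typically driven from the Streamlit sidebar. The
# widget labels and the definitions of `run_state` and `auto_run` below are
# assumptions for illustration; the actual app sets these up before run_loop
# is called.
run_state = st.sidebar.empty()
auto_run = st.sidebar.checkbox("Re-run automatically on change", value=False)
if st.sidebar.button("Run training & evaluation"):
    run_loop(force_run=True)
else:
    run_loop()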
def configure_dataset(default_dir, on_sidebar=True):
    stmod = st.sidebar if on_sidebar else st
    stmod.header("Dataset")
    json_files = glob.glob(f"{default_dir}/*.json")
    all_params_files = [(f, xp.load_json_dict(f)) for f in json_files]
    toy_files = [
        f for (f, p) in all_params_files if p.get("dataset_group", "none") == "toy"
    ]
    dataset = stmod.selectbox("Dataset", toy_files, index=0)
    default_params = xp.load_json_dict(dataset)
    if on_sidebar:
        return default_params

    toy_params = deepcopy(default_params)

    # centers position
    default_centers = np.array([[-0.5, 0.0], [0.5, 0]])
    param_centers = default_params["cluster"].get("centers", default_centers.tolist())
    new_centers_st = stmod.text_input(
        "Position of class centers (source)", param_centers
    )
    new_centers = json.loads(new_centers_st)
    n_clusters = len(new_centers)
    stmod.markdown(f"{n_clusters} classes.")
    toy_params["cluster"]["centers"] = new_centers
    toy_params["cluster"]["n_clusters"] = n_clusters

    # centers radii
    radius0 = default_params["cluster"]["radius"]
    same_radius = stmod.checkbox(
        "Use same variance everywhere (class/dimension)",
        value=isinstance(radius0, float),
    )
    if same_radius:
        if not isinstance(radius0, float):
            radius0 = np.array(radius0).flatten()[0]
        radius = stmod.number_input(
            "Class variance",
            step=10 ** (np.floor(np.log10(radius0))),
            value=radius0,
            format="%.4f",
        )
        toy_params["cluster"]["radius"] = radius
    else:
        if isinstance(radius0, float):
            radii = (np.ones_like(new_centers) * radius0).tolist()
        else:
            radii = radius0
        new_radius_st = stmod.text_input(
            "Variance of each class along each dimension", radii
        )
        new_radius = json.loads(new_radius_st)
        shape_variance = np.array(new_radius).shape
        shape_clusters = np.array(new_centers).shape
        if shape_variance == shape_clusters:
            stmod.markdown(
                ":heavy_check_mark: Shape of variance values matches the shape of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: Shape of variances doesn't match the shape of clusters."
            )
        toy_params["cluster"]["radius"] = new_radius

    # class balance
    proba_classes = default_params["cluster"]["proba_classes"]
    if n_clusters == 2:
        new_proba_classes = stmod.number_input(
            "Probability of class 1",
            step=10 ** (np.floor(np.log10(proba_classes))),
            value=proba_classes,
            format="%.4f",
        )
    else:
        new_proba_classes_st = stmod.text_input(
            "Weight or probability of each class (will be normalized to sum to 1)",
            proba_classes,
        )
        new_proba_classes = json.loads(new_proba_classes_st)
        nb_probas = len(new_proba_classes)
        if nb_probas == n_clusters:
            stmod.markdown(
                ":heavy_check_mark: Class probability values match the number of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: Class probability values don't match the number of clusters."
            )
    toy_params["cluster"]["proba_classes"] = new_proba_classes

    # target shift
    default_cond_shift = default_params["shift"]["data_shift"]
    if n_clusters == 2:
        cond_shift = stmod.checkbox(
            "Class-conditional shift", value="cond" in default_cond_shift
        )
    else:
        cond_shift = False
    if cond_shift:
        rotation0 = default_params["shift"]["re"]
        if isinstance(rotation0, float):
            default_r0 = rotation0
            default_r1 = rotation0
        else:
            default_r0 = default_params["shift"]["re"][0]
            default_r1 = default_params["shift"]["re"][1]
        re0 = stmod.slider(
            "Rotation class 0",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_r0,
        )
        re1 = stmod.slider(
            "Rotation class 1",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_r1,
        )
        transl0 = default_params["shift"]["te"]
        if isinstance(transl0, float):
            default_t0 = transl0
            default_t1 = transl0
        else:
            default_t0 = default_params["shift"]["te"][0]
            default_t1 = default_params["shift"]["te"][1]
        te0 = stmod.slider(
            "Translation class 0",
            min_value=-3.0,
            max_value=3.0,
            value=default_t0,
        )
        te1 = stmod.slider(
            "Translation class 1",
            min_value=-3.0,
            max_value=3.0,
            value=default_t1,
        )
        toy_params["shift"]["re"] = [re0, re1]
        toy_params["shift"]["te"] = [te0, te1]
    else:
        re = stmod.slider(
            "Rotation",
            min_value=-np.pi,
            max_value=np.pi,
            value=default_params["shift"]["re"],
        )
        te = stmod.slider(
            "Translation",
            min_value=-3.0,
            max_value=3.0,
            value=default_params["shift"]["te"],
        )
        toy_params["shift"]["re"] = re
        toy_params["shift"]["te"] = te

    test_view_data(toy_params)

    # choose a new (unique) name for the dataset and save
    data_hash = xp.param_to_hash(toy_params)
    default_hash = xp.param_to_hash(default_params)
    default_name = toy_params["dataset_name"]
    if default_hash != data_hash:
        if toy_params["dataset_name"] == default_params["dataset_name"]:
            default_name = data_hash
    data_name = st.text_input("Choose a (unique) name for your dataset", default_name)
    data_name = data_name.replace(" ", "_")
    toy_params["dataset_name"] = data_name
    data_file = os.path.join(default_dir, f"{data_name}.json")
    if os.path.exists(data_file):
        st.text(f"A dataset with this name already exists! {data_file}")
    else:
        if st.button("Save dataset"):
            with open(data_file, "w") as fd:
                fd.write(json.dumps(toy_params))
            default_params = deepcopy(toy_params)
            st.text(f"Configuration saved to {data_file}")
    return toy_params, data_hash
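# Example of the JSON structure configure_dataset reads and writes (sketch: the
# keys match those accessed above, but the values are illustrative assumptions):
#
# {
#     "dataset_group": "toy",
#     "dataset_name": "blobs_2_classes",
#     "cluster": {
#         "centers": [[-0.5, 0.0], [0.5, 0.0]],
#         "n_clusters": 2,
#         "radius": 0.02,
#         "proba_classes": 0.5
#     },
#     "shift": {
#         "data_shift": "cond_shift",
#         "re": 0.0,
#         "te": 0.0
#     }
# }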
def get_data_hash(self):
    return xp.param_to_hash(self._data_config)