# Example 1
    def __init__(
        self,
        data_path,  # for compatibility with other datasets API
        train=True,
        transform=None,
        download=True,
        cluster_params=None,
        n_samples=300,
    ):
        """Init Blobs dataset.

        Args:
            data_path: root directory; data lives under ``data_path/raw_folder``.
            train: if True load the training split, otherwise the test split.
            transform: transform parameters; stored on ``self.transform``
                (``{}`` when None) and hashed to select the data sub-directory.
            download: if True, ``create_on_disk()`` is called even when the
                data files already exist.
            cluster_params: dict of cluster-generation parameters. When None,
                a default 2-cluster configuration is used.
            n_samples: stored on ``self.n_samples``.

        Raises:
            RuntimeError: if the data files are still missing after generation.
        """
        super(CausalBlobs, self).__init__()
        self.root = data_path
        self.transform = transform if transform is not None else {}
        self.train = train  # training set or test set
        self.n_samples = n_samples

        if cluster_params is None:
            self.cluster_params = dict(
                n_clusters=2, data_seed=0, radius=0.02, centers=None, proba_classes=0.5
            )
        else:
            self.cluster_params = cluster_params

        # Hash the *effective* parameters (defaults included): using the raw
        # ``cluster_params`` argument crashed when it was None. Centers given as
        # a numpy array are converted to a list on a shallow copy (presumably so
        # that hashing / writing parameters.json can serialize them).
        tmp_cluster_params = self.cluster_params.copy()
        if isinstance(tmp_cluster_params.get("centers"), np.ndarray):
            tmp_cluster_params["centers"] = tmp_cluster_params["centers"].tolist()
        cluster_hash = xp.param_to_hash(tmp_cluster_params)
        transform_hash = xp.param_to_hash(transform)
        # Each (cluster params, transform) combination gets its own sub-directory.
        self.data_dir = os.path.join(cluster_hash, transform_hash)
        root_dir = os.path.join(self.root, self.raw_folder)
        os.makedirs(root_dir, exist_ok=True)
        xp.record_hashes(
            os.path.join(root_dir, "parameters.json"),
            f"{cluster_hash}/{transform_hash}",
            {"cluster_params": tmp_cluster_params, "transform": transform},
        )

        self.training_file = "causal_blobs_train.pt"
        self.test_file = "causal_blobs_test.pt"
        self._cluster_gen = None

        if not self._check_exists() or download:
            self.create_on_disk()

        if not self._check_exists():
            raise RuntimeError("Dataset not found.")

        data_file = self.training_file if self.train else self.test_file
        self.data, self.targets = torch.load(
            os.path.join(self.root, self.raw_folder, self.data_dir, data_file)
        )
# Example 2
    # Parameters of the run under test (shallow copy of network_params).
    test_params = network_params.copy()

    data_factory = DatasetFactory(data_params, data_path=args.data_path)
    # Each dataset gets its own results sub-directory under args.outdir.
    output_dir = os.path.join(args.outdir, data_factory.get_data_short_name())

    os.makedirs(output_dir, exist_ok=True)
    # parameters that change across experiments for the same dataset:
    # the run parameters plus every data parameter except "dataset_group" and
    # "dataset_name" (presumably already reflected in output_dir — confirm).
    record_params = test_params.copy()
    record_params.update(
        {
            k: v
            for k, v in data_params.items()
            if k not in ("dataset_group", "dataset_name")
        }
    )
    # Output files are named after the hash of record_params; the hash ->
    # parameters mapping is recorded in output_dir/parameters.json.
    params_hash = xp.param_to_hash(record_params)
    hash_file = os.path.join(output_dir, "parameters.json")
    xp.record_hashes(hash_file, params_hash, record_params)
    output_file_prefix = os.path.join(output_dir, params_hash)

    test_csv_file = f"{output_file_prefix}.csv"
    checkpoint_dir = os.path.join(output_dir, "checkpoints", params_hash)

    # Results table with these metric columns, backed by test_csv_file
    # (presumably loads existing rows when the file exists — verify in xpr).
    results = xpr.XpResults.from_file(
        ["source acc", "target acc", "domain acc"], test_csv_file
    )
    # NOTE(review): not read in the visible part of this function.
    do_plots = False

    # Per-method parameters loaded from the JSON file given by args.method.
    methods_variant_params = xp.load_json_dict(args.method)

    # Presumably filled per method further down (not visible here).
    archis_res = {}
# Example 3
def run_loop(force_run=False):
    """Train/evaluate the selected methods, then display results and figures.

    Reads module-level settings defined elsewhere in the app (``auto_run``,
    ``fewshot``, ``nseeds``, ``methods``, ``toy_params``, ``test_params``,
    ``data_hash``, ``output_dir``, ``gpus``, ``fig_names``, ``args``,
    ``run_state``).

    Args:
        force_run: run even when ``auto_run`` is disabled. When neither is set,
            only a hint asking the user to click the run button is shown.
    """
    if not (auto_run or force_run):
        run_state.markdown(
            ":warning: Click button on the left to run training & evaluation."
        )
        return
    run_state.markdown(
        f":timer_clock: Running with settings: few={fewshot}, nseeds={nseeds}, methods={methods}."
    )

    data_factory = DatasetFactory(toy_params,
                                  data_path=args.data_path,
                                  n_fewshot=fewshot)
    os.makedirs(output_dir, exist_ok=True)
    test_hash = xp.param_to_hash(test_params)
    output_file_prefix = os.path.join(output_dir, test_hash)
    checkpoint_dir = os.path.join(output_dir, "checkpoints", test_hash)

    # Record hash -> parameters indexes for the hash-named output files.
    xp.record_hashes(os.path.join(args.outdir, "data_hashes.json"), data_hash,
                     toy_params)
    xp.record_hashes(os.path.join(args.outdir, "test_hashes.json"), test_hash,
                     test_params)

    test_csv_file = f"{output_file_prefix}.csv"
    results = xpr.XpResults.from_file(
        ["source acc", "target acc", "domain acc"], test_csv_file)

    for m in methods:
        st.write(f"Learning {nseeds} x {m}")
        pgbar = st.progress(0)
        # The returned architecture is not kept: the architectures used for
        # plotting are recomputed for the mean seed further below.
        xp.loop_train_test_model(
            m,
            results,
            nseeds,
            test_csv_file,
            test_params=test_params,
            data_factory=data_factory,
            force_run=False,
            gpus=gpus,
            checkpoint_dir=checkpoint_dir,
            # Bind this iteration's bar now (avoids late-binding the loop var).
            progress_callback=lambda percent, bar=pgbar: bar.progress(percent),
        )
        # Saved after every method, so completed methods are on disk even if a
        # later one fails.
        results.to_csv(test_csv_file)

    st.header("Results summary")
    st.write(f"Read from {test_csv_file}")
    st.dataframe(results.get_data().groupby(["method", "split"]).mean())
    print(results.get_data())
    # catplot is a figure-level seaborn function that draws into a new figure
    # of its own, so no plt.figure() beforehand (it would leak an empty
    # figure). The figure is passed to st.pyplot explicitly and then closed.
    sns.catplot(x="method",
                y="target acc",
                data=results.get_data(),
                kind="swarm",
                hue="split")
    plt.ylabel("Accuracy")
    summary_fig = plt.gcf()
    st.pyplot(summary_fig)
    plt.close(summary_fig)

    results.append_to_txt(
        filepath=os.path.join(output_dir, "all_res.txt"),
        test_params=test_params,
        nseeds=nseeds,
    )
    st.header("Plots")

    # NOTE: the seed used below is the *mean* seed from get_mean_seed.
    logging.info("Recomputing context for best seed")

    mean_seed = results.get_mean_seed("target acc")
    archis_res = get_archi_or_file_for_seed(
        seed=mean_seed,
        methods=methods,
        file_prefix=output_file_prefix,
        test_params=test_params,
        data_factory=data_factory,
        checkpoint_dir=checkpoint_dir,
        gpus=gpus,
        fig_names=fig_names,
    )
    output_file_prefix = f"{output_file_prefix}_{mean_seed}"

    from PIL import Image

    logging.getLogger("matplotlib").setLevel(logging.ERROR)
    # Saved figure files are matched by the figure name with whitespace
    # replaced by underscores.
    fig_file_keys = {fig_name: "_".join(fig_name.split()) for fig_name in fig_names}
    all_figs = defaultdict(dict)
    for name, res in archis_res.items():
        if isinstance(res, list):
            # res is a list of figure file paths: load the matching ones.
            for fig_name, file_key in fig_file_keys.items():
                for fig_file in res:
                    if file_key in fig_file:
                        all_figs[fig_name][name] = Image.open(fig_file)
        elif res is not None:
            figs = plot_archi_data(
                res,
                name,
                save_prefix=output_file_prefix,
                plot_f_lines="neurons" in fig_names,
                do_domain_boundary="domain boundary" in fig_names,
                plot_features=set(fig_names) & {"PCA", "UMAP", "TSNE"},
                do_entropy="entropy" in fig_names,
            )
            for fig_name, fig in figs.items():
                if fig_name in fig_names:
                    all_figs[fig_name][name] = fig
            plt.close("all")
        else:
            st.write(f"Cannot show result for {name}: {res}")

    for fig_name, figs in all_figs.items():
        st.subheader(fig_name)
        for meth_name, fig in figs.items():
            st.write(meth_name)
            if isinstance(fig, Image.Image):
                st.image(fig, use_column_width=True)
            else:
                st.pyplot(fig)

    logging.info(f"See results with prefix {output_file_prefix}")
    run_state.text(f"Done, see results in {output_file_prefix}.")
# Example 4
def _first_float(value):
    """Return ``value`` as a float, taking its first element if it is array-like."""
    if np.isscalar(value):
        return float(value)
    return float(np.asarray(value).flatten()[0])


def _float_pair(value):
    """Return two floats: ``value`` twice if scalar, else its first two items."""
    if np.isscalar(value):
        return float(value), float(value)
    return float(value[0]), float(value[1])


def configure_dataset(default_dir, on_sidebar=True):
    """Let the user select (and, in the main view, edit and save) a toy dataset.

    Candidate datasets are the ``*.json`` files of ``default_dir`` whose
    ``dataset_group`` is ``"toy"``.

    Args:
        default_dir: directory holding the dataset JSON configurations; an
            edited configuration is saved there as ``<name>.json``.
        on_sidebar: if True, only show the dataset selector in the sidebar and
            return the selected configuration unchanged.

    Returns:
        If ``on_sidebar`` is True: the selected configuration dict.
        Otherwise: a ``(toy_params, data_hash)`` tuple with the edited
        configuration and its hash.
    """
    stmod = st.sidebar if on_sidebar else st
    stmod.header("Dataset")
    json_files = glob.glob(f"{default_dir}/*.json")
    all_params_files = [(f, xp.load_json_dict(f)) for f in json_files]
    toy_files = [
        f for (f, p) in all_params_files if p.get("dataset_group", "none") == "toy"
    ]
    dataset = stmod.selectbox("Dataset", toy_files, index=0)
    default_params = xp.load_json_dict(dataset)
    if on_sidebar:
        # NOTE(review): a bare dict here but a tuple at the end of the function.
        return default_params

    toy_params = deepcopy(default_params)

    # Text inputs are pre-filled with JSON so that json.loads can always parse
    # the unedited default (str() renders e.g. None/True, which is not JSON).

    # centers position
    default_centers = np.array([[-0.5, 0.0], [0.5, 0]])
    param_centers = default_params["cluster"].get("centers")
    if param_centers is None:  # missing key or explicit null
        param_centers = default_centers.tolist()
    new_centers_st = stmod.text_input(
        "Position of class centers (source)", json.dumps(param_centers)
    )
    new_centers = json.loads(new_centers_st)
    n_clusters = len(new_centers)
    stmod.markdown(f"{n_clusters} classes.")
    toy_params["cluster"]["centers"] = new_centers
    toy_params["cluster"]["n_clusters"] = n_clusters

    # centers radii (int scalars from JSON count as scalars too)
    radius0 = default_params["cluster"]["radius"]
    same_radius = stmod.checkbox(
        "Use same variance everywhere (class/dimension)",
        value=np.isscalar(radius0),
    )
    if same_radius:
        radius0 = _first_float(radius0)
        radius = stmod.number_input(
            "Class variance",
            step=float(10 ** np.floor(np.log10(radius0))),
            value=radius0,
            format="%.4f",
        )
        toy_params["cluster"]["radius"] = radius
    else:
        if np.isscalar(radius0):
            radii = (np.ones_like(new_centers, dtype=float) * radius0).tolist()
        else:
            radii = radius0
        new_radius_st = stmod.text_input(
            "Variance of each class along each dimension", json.dumps(radii)
        )
        new_radius = json.loads(new_radius_st)
        shape_variance = np.array(new_radius).shape
        shape_clusters = np.array(new_centers).shape
        if shape_variance == shape_clusters:
            stmod.markdown(
                ":heavy_check_mark: Shape of variance values matches the shape of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: Shape of variances doesn't match the shape of clusters."
            )

        toy_params["cluster"]["radius"] = new_radius

    # class balance
    proba_classes = default_params["cluster"]["proba_classes"]
    if n_clusters == 2:
        new_proba_classes = stmod.number_input(
            "Probability of class 1",
            step=10 ** (np.floor(np.log10(proba_classes))),
            value=proba_classes,
            format="%.4f",
        )
    else:
        new_proba_classes_st = stmod.text_input(
            "Weight or probability of each class (will be normalized to sum to 1)",
            json.dumps(proba_classes),
        )
        new_proba_classes = json.loads(new_proba_classes_st)
        nb_probas = len(new_proba_classes)
        if nb_probas == n_clusters:
            stmod.markdown(
                ":heavy_check_mark: class probas values matches the number of clusters."
            )
        else:
            stmod.markdown(
                ":warning: Warning: class probas values don't match the number of clusters."
            )
    toy_params["cluster"]["proba_classes"] = new_proba_classes

    # target shift
    default_cond_shift = default_params["shift"]["data_shift"]
    if n_clusters == 2:
        cond_shift = stmod.checkbox(
            "Class-conditional shift", value="cond" in default_cond_shift
        )
    else:
        cond_shift = False

    if cond_shift:
        # Defaults may be one scalar (int or float) or one value per class.
        default_r0, default_r1 = _float_pair(default_params["shift"]["re"])
        re0 = stmod.slider(
            "Rotation class 0", min_value=-np.pi, max_value=np.pi, value=default_r0,
        )
        re1 = stmod.slider(
            "Rotation class 1", min_value=-np.pi, max_value=np.pi, value=default_r1,
        )
        default_t0, default_t1 = _float_pair(default_params["shift"]["te"])
        te0 = stmod.slider(
            "Translation class 0", min_value=-3.0, max_value=3.0, value=default_t0,
        )
        te1 = stmod.slider(
            "Translation class 1", min_value=-3.0, max_value=3.0, value=default_t1,
        )
        toy_params["shift"]["re"] = [re0, re1]
        toy_params["shift"]["te"] = [te0, te1]
    else:
        # A per-class (list) default would turn these into range sliders, so
        # reduce it to a single float.
        rotation = stmod.slider(
            "Rotation",
            min_value=-np.pi,
            max_value=np.pi,
            value=_first_float(default_params["shift"]["re"]),
        )
        translation = stmod.slider(
            "Translation",
            min_value=-3.0,
            max_value=3.0,
            value=_first_float(default_params["shift"]["te"]),
        )
        toy_params["shift"]["re"] = rotation
        toy_params["shift"]["te"] = translation

    test_view_data(toy_params)

    # choose a new (unique) name for the dataset and save
    data_hash = xp.param_to_hash(toy_params)
    default_hash = xp.param_to_hash(default_params)
    default_name = toy_params["dataset_name"]
    # Edited parameters under an unchanged name: suggest the hash as the name.
    if default_hash != data_hash:
        if toy_params["dataset_name"] == default_params["dataset_name"]:
            default_name = data_hash

    data_name = st.text_input("Choose a (unique) name for your dataset", default_name)
    data_name = data_name.replace(" ", "_")
    toy_params["dataset_name"] = data_name
    data_file = os.path.join(default_dir, f"{data_name}.json")
    if os.path.exists(data_file):
        st.text(f"Data set with this name exists! {data_file}")
    elif st.button("Save dataset"):
        with open(data_file, "w") as fd:
            fd.write(json.dumps(toy_params))
        st.text(f"Configuration saved to {data_file}")

    return toy_params, data_hash
# Example 5
 def get_data_hash(self):
     """Return the hash computed by ``xp.param_to_hash`` over ``self._data_config``."""
     data_config = self._data_config
     return xp.param_to_hash(data_config)