Example #1
def calculatepafdistancebounds(config,
                               shuffle=0,
                               trainingsetindex=0,
                               modelprefix="",
                               numdigits=0,
                               onlytrain=False):
    """
    Returns the distances along PAF edges in the train/test data.

    Parameters
    ----------
    config : string
        Full path of the config.yaml file as a string.

    shuffle: int, optional
        Integer specifying the shuffle index of the training dataset. The default is 0.

    trainingsetindex: int, optional
        Integer specifying which TrainingsetFraction to use. By default the first is used
        (note that TrainingFraction is a list in config.yaml). This variable can also be
        set to "all".

    modelprefix: string, optional
        Directory prefix of the model folder (defaults to the empty string).

    numdigits: int, optional
        Number of digits to which the distances are rounded. The default is 0.

    onlytrain: bool, optional
        If True, only distances on the training data are extracted. The default is False.
    """
    import os
    from pathlib import Path

    import numpy as np
    import pandas as pd

    from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal
    from deeplabcut.pose_estimation_tensorflow.config import load_config

    # Read the project config file; its settings are passed on below.
    cfg = auxiliaryfunctions.read_config(config)

    if cfg["multianimalproject"] == True:
        (
            individuals,
            uniquebodyparts,
            multianimalbodyparts,
        ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)

        # Loading human-annotated data
        trainingsetfolder = auxiliaryfunctions.GetTrainingSetFolder(cfg)
        trainFraction = cfg["TrainingFraction"][trainingsetindex]
        datafn, metadatafn = auxiliaryfunctions.GetDataandMetaDataFilenames(
            trainingsetfolder, trainFraction, shuffle, cfg)
        modelfolder = os.path.join(
            cfg["project_path"],
            str(
                auxiliaryfunctions.GetModelFolder(trainFraction,
                                                  shuffle,
                                                  cfg,
                                                  modelprefix=modelprefix)),
        )

        # Load meta data & annotations
        (
            data,
            trainIndices,
            testIndices,
            trainFraction,
        ) = auxiliaryfunctions.LoadMetadata(
            os.path.join(cfg["project_path"], metadatafn))
        Data = pd.read_hdf(
            os.path.join(
                cfg["project_path"],
                str(trainingsetfolder),
                "CollectedData_" + cfg["scorer"] + ".h5",
            ),
            "df_with_missing",
        )[cfg["scorer"]]

        path_test_config = Path(modelfolder) / "test" / "pose_cfg.yaml"
        dlc_cfg = load_config(str(path_test_config))

        # get the graph!
        partaffinityfield_graph = dlc_cfg.partaffinityfield_graph
        jointnames = [
            dlc_cfg.all_joints_names[i] for i in range(len(dlc_cfg.all_joints))
        ]

        # plt.figure()
        Cutoff = {}

        path_inferencebounds_config = (Path(modelfolder) / "test" /
                                       "inferencebounds.yaml")
        inferenceboundscfg = {}
        for pi, edge in enumerate(partaffinityfield_graph):
            j1, j2 = jointnames[edge[0]], jointnames[edge[1]]
            ds_within = []
            ds_across = []
            for ind in individuals:
                for ind2 in individuals:
                    if ind != "single" and ind2 != "single":
                        if (ind, j1, "x") in Data.keys() and (
                                ind2,
                                j2,
                                "y",
                        ) in Data.keys():
                            distances = np.sqrt(
                                (Data[ind, j1, "x"] - Data[ind2, j1, "x"])**2 +
                                (Data[ind, j1, "y"] - Data[ind2, j2, "y"])**2)
                        else:
                            distances = None

                        if distances is not None:
                            if onlytrain:  # extracts only distances on training data.
                                distances = distances.iloc[trainIndices]

                            # source=np.array(Data[ind,j1,'x'][jj],Data[ind,j1,'y'][jj])
                            # target=np.array(Data[ind2,j1,'x'][jj],Data[ind2,j2,'y'][jj])
                            if ind == ind2:
                                ds_within.extend(distances.values.flatten())
                            else:
                                ds_across.extend(distances.values.flatten())

            edgeencoding = str(edge[0]) + "_" + str(edge[1])
            inferenceboundscfg[edgeencoding] = {}
            if len(ds_within) > 0:
                inferenceboundscfg[edgeencoding]["intra_max"] = str(
                    round(np.nanmax(ds_within), numdigits))
                inferenceboundscfg[edgeencoding]["intra_min"] = str(
                    round(np.nanmin(ds_within), numdigits))
            else:
                inferenceboundscfg[edgeencoding]["intra_max"] = str(
                    1e5)  # large number (larger than any image diameter)
                inferenceboundscfg[edgeencoding]["intra_min"] = str(1e5)

            # NOTE: the inter-animal distances are currently not used, but are interesting to compare to intra_*
            if len(ds_across) > 0:
                inferenceboundscfg[edgeencoding]["inter_max"] = str(
                    round(np.nanmax(ds_across), numdigits))
                inferenceboundscfg[edgeencoding]["inter_min"] = str(
                    round(np.nanmin(ds_across), numdigits))
            else:
                inferenceboundscfg[edgeencoding]["inter_max"] = str(
                    1e5
                )  # large number (larger than image diameters in typical experiments)
                inferenceboundscfg[edgeencoding]["inter_min"] = str(1e5)

            # print(inferenceboundscfg)
            # plt.subplot(len(partaffinityfield_graph),1,pi+1)
            # plt.hist(ds_within,bins=np.linspace(0,100,21),color='red')
            # plt.hist(ds_across,bins=np.linspace(0,100,21),color='blue')

        auxiliaryfunctions.write_plainconfig(str(path_inferencebounds_config),
                                             dict(inferenceboundscfg))
        return inferenceboundscfg
    else:
        print("You might as well bring owls to Athens.")
        return {}
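A minimal usage sketch (the config path and the printed values are placeholders, not outputs from a real project; the returned dict maps "i_j" edge keys to stringified distance bounds, as built above):

# Hypothetical usage -- path and example values are illustrative only.
bounds = calculatepafdistancebounds(
    "/analysis/project/reaching-task/config.yaml",
    shuffle=0,
    numdigits=2,
    onlytrain=True,  # restrict the statistics to training images
)
print(bounds.get("0_1"))  # e.g. {'intra_max': '53.42', 'intra_min': '12.07', ...}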
def merge_annotateddatasets(cfg, trainingsetfolder_full, windows2linux):
    """
    Merges all the h5 files for all labeled-datasets (from individual videos).

    This is a bit of a mess because of cross platform compatibility.

    Within platform comp. is straightforward. But if someone labels on windows and wants to train on a unix cluster or colab...
    """
    AnnotationData = []
    data_path = Path(os.path.join(cfg["project_path"], "labeled-data"))
    videos = cfg["video_sets"].keys()
    for video in videos:
        _, filename, _ = _robust_path_split(video)
        file_path = os.path.join(data_path / filename,
                                 f'CollectedData_{cfg["scorer"]}.h5')
        try:
            data = pd.read_hdf(file_path, "df_with_missing")
            AnnotationData.append(data)
        except FileNotFoundError:
            print(
                file_path,
                " not found (perhaps not annotated). If training on cropped data, "
                "make sure to call `cropimagesandlabels` prior to creating the dataset.",
            )

    if not len(AnnotationData):
        print(
            "Annotation data was not found by splitting video paths (from config['video_sets']). An alternative route is taken..."
        )
        AnnotationData = conversioncode.merge_windowsannotationdataONlinuxsystem(
            cfg)
        if not len(AnnotationData):
            print("No data was found!")
            return

    AnnotationData = pd.concat(AnnotationData).sort_index()
    # When concatenating DataFrames with misaligned column labels,
    # all sorts of reordering may happen (mainly depending on 'sort' and 'join')
    # Ensure the 'bodyparts' level agrees with the order in the config file.
    if cfg.get("multianimalproject", False):
        (
            _,
            uniquebodyparts,
            multianimalbodyparts,
        ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
        bodyparts = multianimalbodyparts + uniquebodyparts
    else:
        bodyparts = cfg["bodyparts"]
    AnnotationData = AnnotationData.reindex(
        bodyparts,
        axis=1,
        level=AnnotationData.columns.names.index("bodyparts"))

    # Check whether the code is *not* running on Windows
    # (source: https://stackoverflow.com/questions/1325581)
    # while the paths are in Windows format...
    windowspath = "\\" in AnnotationData.index[0]
    if os.name != "nt" and windowspath and not windows2linux:
        print(
            "It appears that the images were labeled on a Windows system, but you are "
            "currently trying to create a training set on a Unix system.\n"
            "In this case the paths should be converted. Do you want to proceed with the conversion?"
        )
        askuser = input("yes/no")
    else:
        askuser = "no"  # value masked in the original listing; "no" leaves the paths unchanged

    filename = os.path.join(trainingsetfolder_full,
                            f'CollectedData_{cfg["scorer"]}')
    if (windows2linux or askuser == "yes" or askuser == "y" or askuser
            == "Ja"):  # convert windows path in pandas array \\ to unix / !
        AnnotationData = conversioncode.convertpaths_to_unixstyle(
            AnnotationData, filename)
        print("Annotation data converted to unix format...")
    else:  # store as is
        AnnotationData.to_hdf(filename + ".h5",
                              key="df_with_missing",
                              mode="w")
        AnnotationData.to_csv(filename + ".csv")  # human readable.

    return AnnotationData
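The Unix-path conversion delegated to conversioncode.convertpaths_to_unixstyle above boils down to rewriting the separators in the image-path index; a standalone sketch of the idea (illustrative only, not the library implementation):

import pandas as pd

def to_unix_paths(df: pd.DataFrame) -> pd.DataFrame:
    # Replace Windows '\\' separators in the image-path index with '/'.
    out = df.copy()
    out.index = [p.replace("\\", "/") for p in out.index]
    return out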
    print("Config edited.")

    print("Extracting frames...")
    deeplabcut.extract_frames(config_path, mode="automatic", userfeedback=False)
    print("Frames extracted.")

    print("Creating artificial data...")
    rel_folder = os.path.join("labeled-data", os.path.splitext(video)[0])
    image_folder = os.path.join(cfg["project_path"], rel_folder)
    n_animals = len(cfg["individuals"])
    (
        animals,
        bodyparts_single,
        bodyparts_multi,
    ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
    animals_id = [i for i in range(n_animals) for _ in bodyparts_multi] + [
        n_animals
    ] * len(bodyparts_single)
    map_ = dict(zip(range(len(animals)), animals))
    individuals = [map_[ind] for ind in animals_id for _ in range(2)]
    scorer = [SCORER] * len(individuals)
    coords = ["x", "y"] * len(animals_id)
    bodyparts = [
        bp for _ in range(n_animals) for bp in bodyparts_multi for _ in range(2)
    ]
    bodyparts += [bp for bp in bodyparts_single for _ in range(2)]
    columns = pd.MultiIndex.from_arrays(
        [scorer, individuals, bodyparts, coords],
        names=["scorer", "individuals", "bodyparts", "coords"],
    )
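For reference, a toy version of such a column index. Here pd.MultiIndex.from_product suffices because every individual shares the same bodyparts (from_arrays is needed above since unique bodyparts belong only to the "single" pseudo-individual); all names below are made up:

import numpy as np
import pandas as pd

scorer, animals, bodyparts = "alex", ["mouse1", "mouse2"], ["nose", "tailbase"]
columns = pd.MultiIndex.from_product(
    [[scorer], animals, bodyparts, ["x", "y"]],
    names=["scorer", "individuals", "bodyparts", "coords"],
)
frames = ["labeled-data/vid/img000.png", "labeled-data/vid/img001.png"]
fake = pd.DataFrame(np.random.rand(len(frames), len(columns)) * 100,
                    index=frames, columns=columns)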
Example #4
def create_multianimaltraining_dataset(
    config,
    num_shuffles=1,
    Shuffles=None,
    windows2linux=False,
    net_type=None,
    numdigits=2,
    paf_graph=None,
):
    """
    Creates a training dataset for multi-animal projects. Labels from all the extracted frames are merged into a single .h5 file.
    Only the videos included in the config file are used to create this dataset.
    [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more videos to the project.

    Important differences to the standard training dataset:
     - coordinates are stored rounded to `numdigits` decimal places
     - a part affinity field (PAF) graph and an inference configuration are created

    Parameters
    ----------
    config : string
        Full path of the config.yaml file as a string.

    num_shuffles : int, optional
        Number of shuffles of training dataset to create, i.e. [1,2,3] for num_shuffles=3. Default is set to 1.

    Shuffles: list of integers, optional
        Alternatively, the user can pass an explicit list of shuffle indices (integers!).

    windows2linux: bool.
        The annotation files contain paths formatted according to your operating system. If you label on Windows
        but train & evaluate on a Unix system (e.g. Ubuntu, Colab, macOS), set this variable to True to convert the paths.

    net_type: string
        Type of network. Currently resnet_50, resnet_101, resnet_152, efficientnet-b0 through
        efficientnet-b6, as well as dlcrnet_ms5 are supported (not the MobileNets!).
        See Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1

    numdigits: int, optional
        Number of decimal places used when storing the label coordinates. Default is 2.

    paf_graph: list of lists, optional (default=None)
        If not None, overwrite the default complete graph. This is useful for advanced users who
        already know a good graph, or simply want to use a specific one. Note that, in that case,
        the data-driven selection procedure upon model evaluation will be skipped.

    Example
    --------
    >>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1)

    Windows:
    >>> deeplabcut.create_multianimaltraining_dataset(r'C:\\Users\\Ulf\\looming-task\\config.yaml',Shuffles=[3,17,5])
    --------
    """

    # Loading metadata from config file:
    cfg = auxiliaryfunctions.read_config(config)
    scorer = cfg["scorer"]
    project_path = cfg["project_path"]
    # Create path for training sets & store data there
    trainingsetfolder = auxiliaryfunctions.GetTrainingSetFolder(cfg)
    full_training_path = Path(project_path, trainingsetfolder)
    auxiliaryfunctions.attempttomakefolder(full_training_path, recursive=True)

    Data = merge_annotateddatasets(cfg, full_training_path, windows2linux)
    if Data is None:
        return
    Data = Data[scorer]

    def strip_cropped_image_name(path):
        # Utility function mapping different crops of the same image to a common
        # name, so that all crops of one image land in either train or test.
        head, filename = os.path.split(path)
        if cfg["croppedtraining"]:
            # Cropped frames are assumed to be named '<stem>c<index>.<ext>';
            # strip everything from the first 'c' to recover the common stem.
            filename = filename.split("c")[0]
        return os.path.join(head, filename)

    img_names = Data.index.map(strip_cropped_image_name).unique()

    if net_type is None:  # loading & linking pretrained models
        net_type = cfg.get("default_net_type", "dlcrnet_ms5")
    elif not any(net in net_type for net in ("resnet", "eff", "dlc")):
        raise ValueError(f"Unsupported network {net_type}.")

    multi_stage = False
    if net_type == "dlcrnet_ms5":
        net_type = "resnet_50"
        multi_stage = True

    dataset_type = "multi-animal-imgaug"
    (
        individuals,
        uniquebodyparts,
        multianimalbodyparts,
    ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)

    if paf_graph is None:  # Automatically form a complete PAF graph
        partaffinityfield_graph = [
            list(edge)
            for edge in combinations(range(len(multianimalbodyparts)), 2)
        ]
    else:
        # Ignore possible connections between 'multi' and 'unique' body parts;
        # one can never be too careful...
        to_ignore = auxfun_multianimal.filter_unwanted_paf_connections(
            cfg, paf_graph)
        partaffinityfield_graph = [
            edge for i, edge in enumerate(paf_graph) if i not in to_ignore
        ]
        auxfun_multianimal.validate_paf_graph(cfg, partaffinityfield_graph)

    print("Utilizing the following graph:", partaffinityfield_graph)
    partaffinityfield_predict = True

    # Loading the encoder (if necessary downloading from TF)
    dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
    defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
    model_path, num_shuffles = auxfun_models.Check4weights(
        net_type, Path(dlcparent_path), num_shuffles)

    if Shuffles is None:
        Shuffles = range(1, num_shuffles + 1, 1)
    else:
        Shuffles = [i for i in Shuffles if isinstance(i, int)]

    TrainingFraction = cfg["TrainingFraction"]
    for shuffle in Shuffles:  # Creating shuffles starting from 1
        for trainFraction in TrainingFraction:
            train_inds_temp, test_inds_temp = SplitTrials(
                range(len(img_names)), trainFraction)
            # Map back to the original indices.
            temp = [
                re.escape(name) for i, name in enumerate(img_names)
                if i in test_inds_temp
            ]
            mask = Data.index.str.contains("|".join(temp))
            testIndices = np.flatnonzero(mask)
            trainIndices = np.flatnonzero(~mask)

            ####################################################
            # Generating data structure with labeled information & frame metadata (for deep cut)
            ####################################################
            print(
                "Creating training data for: Shuffle:",
                shuffle,
                "TrainFraction: ",
                trainFraction,
            )

            # Make training file!
            data = format_multianimal_training_data(
                Data,
                trainIndices,
                cfg["project_path"],
                numdigits,
            )

            if len(trainIndices) > 0:
                (
                    datafilename,
                    metadatafilename,
                ) = auxiliaryfunctions.GetDataandMetaDataFilenames(
                    trainingsetfolder, trainFraction, shuffle, cfg)
                ################################################################################
                # Saving metadata and data file (Pickle file)
                ################################################################################
                auxiliaryfunctions.SaveMetadata(
                    os.path.join(project_path, metadatafilename),
                    data,
                    trainIndices,
                    testIndices,
                    trainFraction,
                )

                datafilename = datafilename.split(".mat")[0] + ".pickle"
                import pickle

                with open(os.path.join(project_path, datafilename), "wb") as f:
                    # Pickle the 'labeled-data' dictionary using the highest protocol available.
                    pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

                ################################################################################
                # Creating file structure for training &
                # Test files as well as pose_yaml files (containing training and testing information)
                #################################################################################

                modelfoldername = auxiliaryfunctions.GetModelFolder(
                    trainFraction, shuffle, cfg)
                auxiliaryfunctions.attempttomakefolder(
                    Path(config).parents[0] / modelfoldername, recursive=True)
                auxiliaryfunctions.attempttomakefolder(
                    str(Path(config).parents[0] / modelfoldername / "train"))
                auxiliaryfunctions.attempttomakefolder(
                    str(Path(config).parents[0] / modelfoldername / "test"))

                path_train_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "train",
                        "pose_cfg.yaml",
                    ))
                path_test_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "test",
                        "pose_cfg.yaml",
                    ))
                path_inference_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "test",
                        "inference_cfg.yaml",
                    ))

                jointnames = [str(bpt) for bpt in multianimalbodyparts]
                jointnames.extend([str(bpt) for bpt in uniquebodyparts])
                items2change = {
                    "dataset": datafilename,
                    "metadataset": metadatafilename,
                    "num_joints": len(multianimalbodyparts) + len(uniquebodyparts),
                    "all_joints": [
                        [i]
                        for i in range(len(multianimalbodyparts) + len(uniquebodyparts))
                    ],
                    "all_joints_names": jointnames,
                    "init_weights": model_path,
                    "project_path": str(cfg["project_path"]),
                    "net_type": net_type,
                    "multi_stage": multi_stage,
                    "pairwise_loss_weight": 0.1,
                    "pafwidth": 20,
                    "partaffinityfield_graph": partaffinityfield_graph,
                    "partaffinityfield_predict": partaffinityfield_predict,
                    "weigh_only_present_joints": False,
                    "num_limbs": len(partaffinityfield_graph),
                    "dataset_type": dataset_type,
                    "optimizer": "adam",
                    "batch_size": 8,
                    "multi_step": [[1e-4, 7500], [5e-5, 12000], [1e-5, 200000]],
                    "save_iters": 10000,
                    "display_iters": 500,
                    "num_idchannel": len(cfg["individuals"])
                    if cfg.get("identity", False)
                    else 0,
                }

                trainingdata = MakeTrain_pose_yaml(items2change,
                                                   path_train_config,
                                                   defaultconfigfile)
                keys2save = [
                    "dataset",
                    "num_joints",
                    "all_joints",
                    "all_joints_names",
                    "net_type",
                    "multi_stage",
                    "init_weights",
                    "global_scale",
                    "location_refinement",
                    "locref_stdev",
                    "dataset_type",
                    "partaffinityfield_predict",
                    "pairwise_predict",
                    "partaffinityfield_graph",
                    "num_limbs",
                    "dataset_type",
                    "num_idchannel",
                ]

                MakeTest_pose_yaml(
                    trainingdata,
                    keys2save,
                    path_test_config,
                    nmsradius=5.0,
                    minconfidence=0.01,
                )  # setting important def. values for inference

                # Setting inference cfg file:
                defaultinference_configfile = os.path.join(
                    dlcparent_path, "inference_cfg.yaml")
                items2change = {
                    "minimalnumberofconnections": int(len(cfg["multianimalbodyparts"]) / 2),
                    "topktoretain": len(cfg["individuals"])
                    + 1 * (len(cfg["uniquebodyparts"]) > 0),
                    "withid": cfg.get("identity", False),
                }
                MakeInference_yaml(items2change, path_inference_config,
                                   defaultinference_configfile)

                print(
                    "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!"
                )
            else:
                pass
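When paf_graph is None, the default graph built above is simply the complete graph over the multi-animal bodyparts; a compact sketch with a toy bodypart list:

from itertools import combinations

multianimalbodyparts = ["nose", "leftear", "rightear", "tailbase"]  # toy example
paf_graph = [list(e) for e in combinations(range(len(multianimalbodyparts)), 2)]
# n bodyparts yield n * (n - 1) / 2 edges:
# [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]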
def create_multianimaltraining_dataset(
    config,
    num_shuffles=1,
    Shuffles=None,
    windows2linux=False,
    net_type=None,
    numdigits=2,
    crop_size=(400, 400),
    crop_sampling="hybrid",
    paf_graph=None,
    trainIndices=None,
    testIndices=None,
    n_edges_threshold=105,
    paf_graph_degree=6,
):
    """
    Creates a training dataset for multi-animal projects. Labels from all the extracted frames are merged into a single .h5 file.
    Only the videos included in the config file are used to create this dataset.
    [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more videos to the project.

    Important differences to the standard training dataset:
     - coordinates are stored rounded to `numdigits` decimal places
     - a part affinity field (PAF) graph and an inference configuration are created

    Parameters
    ----------
    config : string
        Full path of the config.yaml file as a string.

    num_shuffles : int, optional
        Number of shuffles of training dataset to create, i.e. [1,2,3] for num_shuffles=3. Default is set to 1.

    Shuffles: list of integers, optional
        Alternatively, the user can pass an explicit list of shuffle indices (integers!).

    net_type: string
        Type of network. Currently resnet_50, resnet_101, resnet_152, efficientnet-b0 through
        efficientnet-b6, as well as dlcrnet_ms5 are supported (not the MobileNets!).
        See Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1

    numdigits: int, optional
        Number of decimal places used when storing the label coordinates. Default is 2.

    crop_size: tuple of int, optional
        Dimensions (width, height) of the crops for data augmentation.
        Default is 400x400.

    crop_sampling: str, optional
        Crop centers sampling method. Must be either:
        "uniform" (randomly over the image),
        "keypoints" (randomly over the annotated keypoints),
        "density" (weighing preferentially dense regions of keypoints),
        or "hybrid" (alternating randomly between "uniform" and "density").
        Default is "hybrid".

    paf_graph: list of lists, or "config" optional (default=None)
        If not None, overwrite the default complete graph. This is useful for advanced users who
        already know a good graph, or simply want to use a specific one. Note that, in that case,
        the data-driven selection procedure upon model evaluation will be skipped.

        "config" will use the skeleton defined in the config file.

    trainIndices: list of lists, optional (default=None)
        List of one or multiple lists containing train indexes.
        A list containing two lists of training indexes will produce two splits.

    testIndices: list of lists, optional (default=None)
        List of one or multiple lists containing test indexes.

    n_edges_threshold: int, optional (default=105)
        Number of edges above which the graph is automatically pruned.

    paf_graph_degree: int, optional (default=6)
        Degree of paf_graph when automatically pruning it (before training).
        
    Example
    --------
    >>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1)

    >>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml', Shuffles=[0,1,2], trainIndices=[trainInd1, trainInd2, trainInd3], testIndices=[testInd1, testInd2, testInd3])

    Windows:
    >>> deeplabcut.create_multianimaltraining_dataset(r'C:\\Users\\Ulf\\looming-task\\config.yaml',Shuffles=[3,17,5])
    --------
    """
    if windows2linux:
        warnings.warn(
            "`windows2linux` has no effect since 2.2.0.4 and will be removed in 2.2.1.",
            FutureWarning,
        )

    if len(crop_size) != 2 or not all(isinstance(v, int) for v in crop_size):
        raise ValueError(
            "Crop size must be a tuple of two integers (width, height).")

    if crop_sampling not in ("uniform", "keypoints", "density", "hybrid"):
        raise ValueError(
            f"Invalid sampling {crop_sampling}. Must be "
            "either 'uniform', 'keypoints', 'density', or 'hybrid'.")

    # Loading metadata from config file:
    cfg = auxiliaryfunctions.read_config(config)
    scorer = cfg["scorer"]
    project_path = cfg["project_path"]
    # Create path for training sets & store data there
    trainingsetfolder = auxiliaryfunctions.GetTrainingSetFolder(cfg)
    full_training_path = Path(project_path, trainingsetfolder)
    auxiliaryfunctions.attempttomakefolder(full_training_path, recursive=True)

    Data = merge_annotateddatasets(cfg, full_training_path)
    if Data is None:
        return
    Data = Data[scorer]

    if net_type is None:  # loading & linking pretrained models
        net_type = cfg.get("default_net_type", "dlcrnet_ms5")
    elif not any(net in net_type for net in ("resnet", "eff", "dlc", "mob")):
        raise ValueError(f"Unsupported network {net_type}.")

    multi_stage = False
    ### dlcnet_ms5: backbone resnet50 + multi-fusion & multi-stage module
    ### dlcr101_ms5/dlcr152_ms5: backbone resnet101/152 + multi-fusion & multi-stage module
    if all(net in net_type for net in ("dlcr", "_ms5")):
        num_layers = re.findall("dlcr([0-9]*)", net_type)[0]
        if num_layers == "":
            num_layers = 50
        net_type = "resnet_{}".format(num_layers)
        multi_stage = True

    dataset_type = "multi-animal-imgaug"
    (
        individuals,
        uniquebodyparts,
        multianimalbodyparts,
    ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)

    if paf_graph is None:  # Automatically form a complete PAF graph
        n_bpts = len(multianimalbodyparts)
        partaffinityfield_graph = [
            list(edge) for edge in combinations(range(n_bpts), 2)
        ]
        n_edges_orig = len(partaffinityfield_graph)
        # If the graph is unnecessarily large (with 15+ keypoints by default),
        # we randomly prune it to a size guaranteeing an average node degree of 6;
        # see Suppl. Fig S9c in Lauer et al., 2022.
        if n_edges_orig >= n_edges_threshold:
            partaffinityfield_graph = auxfun_multianimal.prune_paf_graph(
                partaffinityfield_graph,
                average_degree=paf_graph_degree,
            )
    else:
        if paf_graph == "config":
            # Use the skeleton defined in the config file
            skeleton = cfg["skeleton"]
            paf_graph = [
                sorted((multianimalbodyparts.index(bpt1),
                        multianimalbodyparts.index(bpt2)))
                for bpt1, bpt2 in skeleton
            ]
            print(
                "Using `skeleton` from the config file as a paf_graph. Data-driven skeleton will not be computed."
            )

        # Ignore possible connections between 'multi' and 'unique' body parts;
        # one can never be too careful...
        to_ignore = auxfun_multianimal.filter_unwanted_paf_connections(
            cfg, paf_graph)
        partaffinityfield_graph = [
            edge for i, edge in enumerate(paf_graph) if i not in to_ignore
        ]
        auxfun_multianimal.validate_paf_graph(cfg, partaffinityfield_graph)

    print("Utilizing the following graph:", partaffinityfield_graph)
    # Disable the prediction of PAFs if the graph is empty
    partaffinityfield_predict = bool(partaffinityfield_graph)

    # Loading the encoder (if necessary downloading from TF)
    dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
    defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
    model_path, num_shuffles = auxfun_models.Check4weights(
        net_type, Path(dlcparent_path), num_shuffles)

    if Shuffles is None:
        Shuffles = range(1, num_shuffles + 1, 1)
    else:
        Shuffles = [i for i in Shuffles if isinstance(i, int)]

    # print(trainIndices,testIndices, Shuffles, augmenter_type,net_type)
    if trainIndices is None and testIndices is None:
        splits = []
        for shuffle in Shuffles:  # Creating shuffles starting from 1
            for train_frac in cfg["TrainingFraction"]:
                train_inds, test_inds = SplitTrials(range(len(Data)),
                                                    train_frac)
                splits.append((train_frac, shuffle, (train_inds, test_inds)))
    else:
        if not (len(trainIndices) == len(testIndices) == len(Shuffles)):
            raise ValueError(
                "The numbers of Shuffles, train indices, and test indices must be equal."
            )
        splits = []
        for shuffle, (train_inds,
                      test_inds) in enumerate(zip(trainIndices, testIndices)):
            trainFraction = round(
                len(train_inds) * 1.0 / (len(train_inds) + len(test_inds)), 2)
            print(
                f"You passed a split with the following fraction: {int(100 * trainFraction)}%"
            )
            # Now that the training fraction is guaranteed to be correct,
            # the values added to pad the indices are removed.
            train_inds = np.asarray(train_inds)
            train_inds = train_inds[train_inds != -1]
            test_inds = np.asarray(test_inds)
            test_inds = test_inds[test_inds != -1]
            splits.append(
                (trainFraction, Shuffles[shuffle], (train_inds, test_inds)))

    for trainFraction, shuffle, (trainIndices, testIndices) in splits:
        ####################################################
        # Generating data structure with labeled information & frame metadata (for deep cut)
        ####################################################
        print(
            "Creating training data for: Shuffle:",
            shuffle,
            "TrainFraction: ",
            trainFraction,
        )

        # Make training file!
        data = format_multianimal_training_data(
            Data,
            trainIndices,
            cfg["project_path"],
            numdigits,
        )

        if len(trainIndices) > 0:
            (
                datafilename,
                metadatafilename,
            ) = auxiliaryfunctions.GetDataandMetaDataFilenames(
                trainingsetfolder, trainFraction, shuffle, cfg)
            ################################################################################
            # Saving metadata and data file (Pickle file)
            ################################################################################
            auxiliaryfunctions.SaveMetadata(
                os.path.join(project_path, metadatafilename),
                data,
                trainIndices,
                testIndices,
                trainFraction,
            )

            datafilename = datafilename.split(".mat")[0] + ".pickle"
            import pickle

            with open(os.path.join(project_path, datafilename), "wb") as f:
                # Pickle the 'labeled-data' dictionary using the highest protocol available.
                pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

            ################################################################################
            # Creating file structure for training &
            # Test files as well as pose_yaml files (containing training and testing information)
            #################################################################################

            modelfoldername = auxiliaryfunctions.GetModelFolder(
                trainFraction, shuffle, cfg)
            auxiliaryfunctions.attempttomakefolder(Path(config).parents[0] /
                                                   modelfoldername,
                                                   recursive=True)
            auxiliaryfunctions.attempttomakefolder(
                str(Path(config).parents[0] / modelfoldername / "train"))
            auxiliaryfunctions.attempttomakefolder(
                str(Path(config).parents[0] / modelfoldername / "test"))

            path_train_config = str(
                os.path.join(
                    cfg["project_path"],
                    Path(modelfoldername),
                    "train",
                    "pose_cfg.yaml",
                ))
            path_test_config = str(
                os.path.join(
                    cfg["project_path"],
                    Path(modelfoldername),
                    "test",
                    "pose_cfg.yaml",
                ))
            path_inference_config = str(
                os.path.join(
                    cfg["project_path"],
                    Path(modelfoldername),
                    "test",
                    "inference_cfg.yaml",
                ))

            jointnames = [str(bpt) for bpt in multianimalbodyparts]
            jointnames.extend([str(bpt) for bpt in uniquebodyparts])
            items2change = {
                "dataset": datafilename,
                "metadataset": metadatafilename,
                "num_joints": len(multianimalbodyparts) + len(uniquebodyparts),
                "all_joints": [
                    [i]
                    for i in range(len(multianimalbodyparts) + len(uniquebodyparts))
                ],
                "all_joints_names": jointnames,
                "init_weights": model_path,
                "project_path": str(cfg["project_path"]),
                "net_type": net_type,
                "multi_stage": multi_stage,
                "pairwise_loss_weight": 0.1,
                "pafwidth": 20,
                "partaffinityfield_graph": partaffinityfield_graph,
                "partaffinityfield_predict": partaffinityfield_predict,
                "weigh_only_present_joints": False,
                "num_limbs": len(partaffinityfield_graph),
                "dataset_type": dataset_type,
                "optimizer": "adam",
                "batch_size": 8,
                "multi_step": [[1e-4, 7500], [5e-5, 12000], [1e-5, 200000]],
                "save_iters": 10000,
                "display_iters": 500,
                "num_idchannel": len(cfg["individuals"]) if cfg.get("identity", False) else 0,
                "crop_size": list(crop_size),
                "crop_sampling": crop_sampling,
            }

            trainingdata = MakeTrain_pose_yaml(items2change, path_train_config,
                                               defaultconfigfile)
            keys2save = [
                "dataset",
                "num_joints",
                "all_joints",
                "all_joints_names",
                "net_type",
                "multi_stage",
                "init_weights",
                "global_scale",
                "location_refinement",
                "locref_stdev",
                "dataset_type",
                "partaffinityfield_predict",
                "pairwise_predict",
                "partaffinityfield_graph",
                "num_limbs",
                "dataset_type",
                "num_idchannel",
            ]

            MakeTest_pose_yaml(
                trainingdata,
                keys2save,
                path_test_config,
                nmsradius=5.0,
                minconfidence=0.01,
                sigma=1,
                locref_smooth=False,
            )  # setting important def. values for inference

            # Setting inference cfg file:
            defaultinference_configfile = os.path.join(dlcparent_path,
                                                       "inference_cfg.yaml")
            items2change = {
                "minimalnumberofconnections": int(len(cfg["multianimalbodyparts"]) / 2),
                "topktoretain": len(cfg["individuals"])
                + 1 * (len(cfg["uniquebodyparts"]) > 0),
                "withid": cfg.get("identity", False),
            }
            MakeInference_yaml(items2change, path_inference_config,
                               defaultinference_configfile)

            print(
                "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!"
            )
        else:
            pass
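A hypothetical call with user-defined splits, matching the trainIndices/testIndices contract described in the docstring (one list per shuffle, equal counts); the config path and the dataset size of 100 labeled images are placeholders:

import numpy as np
import deeplabcut

config_path = "/analysis/project/reaching-task/config.yaml"  # placeholder
rng = np.random.default_rng(seed=42)
splits = [rng.permutation(100) for _ in range(2)]  # assumes 100 labeled images
deeplabcut.create_multianimaltraining_dataset(
    config_path,
    Shuffles=[1, 2],
    trainIndices=[s[:80].tolist() for s in splits],  # 80% train
    testIndices=[s[80:].tolist() for s in splits],   # 20% test
)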
    def __init__(self, parent, config, video, shuffle, Dataframe, savelabeled,
                 multianimal):
        super(MainFrame,
              self).__init__("DeepLabCut2.0 - Manual Outlier Frame Extraction",
                             parent)

        ###################################################################################################################################################
        # Splitting the frame into top and bottom panels. The bottom panel contains the widgets; the top panel is for showing images and plotting!

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, config, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)

        vSplitter.SplitVertically(self.image_panel,
                                  self.choice_panel,
                                  sashPosition=self.gui_size[0] * 0.8)
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter,
                                      self.widget_panel,
                                      sashPosition=self.gui_size[1] *
                                      0.83)  # 0.9
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)

        self.load_button_sizer = wx.BoxSizer(wx.VERTICAL)
        self.help_button_sizer = wx.BoxSizer(wx.VERTICAL)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        self.help_button_sizer.Add(self.help, 1, wx.ALL, 15)
        #        widgetsizer.Add(self.help , 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        widgetsizer.Add(self.help_button_sizer, 1, wx.ALL, 0)

        self.grab = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="Grab Frames")
        widgetsizer.Add(self.grab, 1, wx.ALL, 15)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.grab.Enable(True)

        widgetsizer.AddStretchSpacer(5)
        self.slider = wx.Slider(
            self.widget_panel,
            id=wx.ID_ANY,
            value=0,
            minValue=0,
            maxValue=1,
            size=(200, -1),
            style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS,
        )
        widgetsizer.Add(self.slider, 1, wx.ALL, 5)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

        widgetsizer.AddStretchSpacer(5)
        self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL)
        self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL)

        self.start_frames_sizer.AddSpacer(15)
        #        self.startFrame = wx.SpinCtrl(self.widget_panel, value='0', size=(100, -1), min=0, max=120)
        self.startFrame = wx.SpinCtrl(self.widget_panel,
                                      value="0",
                                      size=(100, -1))  # ,style=wx.SP_VERTICAL)
        self.startFrame.Enable(False)
        self.start_frames_sizer.Add(self.startFrame, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        start_text = wx.StaticText(self.widget_panel,
                                   label="Start Frame Index")
        self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                    15)
        self.checkBox = wx.CheckBox(self.widget_panel,
                                    id=wx.ID_ANY,
                                    label="Range of frames")
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range)
        self.start_frames_sizer.Add(self.checkBox, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        #
        self.end_frames_sizer.AddSpacer(15)
        self.endFrame = wx.SpinCtrl(self.widget_panel,
                                    value="1",
                                    size=(160, -1))  # , min=1, max=120)
        self.endFrame.Enable(False)
        self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                  15)
        end_text = wx.StaticText(self.widget_panel, label="Number of Frames")
        self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame = wx.Button(self.widget_panel,
                                     id=wx.ID_ANY,
                                     label="Update")
        self.end_frames_sizer.Add(self.updateFrame, 1,
                                  wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider)
        self.updateFrame.Enable(False)

        widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(5)
        widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(15)

        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)
        self.quit.Enable(True)

        self.widget_panel.SetSizer(widgetsizer)
        self.widget_panel.SetSizerAndFit(widgetsizer)

        # Variables initialization
        self.numberFrames = 0
        self.currFrame = 0
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.drs = []
        self.extract_range_frame = False
        self.firstFrame = 0
        self.Colorscheme = []

        # Read config file
        self.cfg = auxiliaryfunctions.read_config(config)
        self.Task = self.cfg["Task"]
        self.start = self.cfg["start"]
        self.stop = self.cfg["stop"]
        self.date = self.cfg["date"]
        self.trainFraction = self.cfg["TrainingFraction"]
        self.trainFraction = self.trainFraction[0]
        self.videos = self.cfg["video_sets"].keys()
        self.bodyparts = self.cfg["bodyparts"]
        self.colormap = plt.get_cmap(self.cfg["colormap"])
        self.colormap = self.colormap.reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.iterationindex = self.cfg["iteration"]
        self.cropping = self.cfg["cropping"]
        self.video_names = [Path(i).stem for i in self.videos]
        self.config_path = Path(config)
        self.video_source = Path(video).resolve()
        self.shuffle = shuffle
        self.Dataframe = Dataframe
        conversioncode.guarantee_multiindex_rows(self.Dataframe)
        self.savelabeled = savelabeled
        self.multianimal = multianimal
        if self.multianimal:
            from deeplabcut.utils import auxfun_multianimal

            (
                self.individual_names,
                self.uniquebodyparts,
                self.multianimalbodyparts,
            ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
            self.choiceBox, self.visualization_rdb = self.choice_panel.addRadioButtons(
            )
            self.Colorscheme = visualization.get_cmap(
                len(self.individual_names), self.cfg["colormap"])
            self.visualization_rdb.Bind(wx.EVT_RADIOBOX, self.clear_plot)
        # Read the video file
        self.vid = VideoWriter(str(self.video_source))
        if self.cropping:
            self.vid.set_bbox(self.cfg["x1"], self.cfg["x2"], self.cfg["y1"],
                              self.cfg["y2"])
        self.filename = Path(self.video_source).name
        self.numberFrames = len(self.vid)
        self.strwidth = int(np.ceil(np.log10(self.numberFrames)))
        # Set the values of slider and range of frames
        self.startFrame.SetMax(self.numberFrames - 1)
        self.slider.SetMax(self.numberFrames - 1)
        self.endFrame.SetMax(self.numberFrames - 1)
        self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider)  # wx.EVT_SPIN
        # Set the status bar
        self.statusbar.SetStatusText("Working on video: {}".format(
            self.filename))
        # Adding the video file to the config file.
        if self.vid.name not in self.video_names:
            add.add_new_videos(self.config_path, [self.video_source])

        self.update()
        self.plot_labels()
        self.widget_panel.Layout()
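The layout above nests a vertical splitter inside a horizontal one; a self-contained wxPython miniature of that pattern (plain panels stand in for the DeepLabCut ones, and the sash positions are arbitrary):

import wx

class MiniFrame(wx.Frame):
    def __init__(self):
        super().__init__(None, title="Splitter sketch", size=(600, 400))
        top = wx.SplitterWindow(self)
        v = wx.SplitterWindow(top)
        image_panel = wx.Panel(v)     # stands in for ImagePanel
        choice_panel = wx.Panel(v)    # stands in for ScrollPanel
        widget_panel = wx.Panel(top)  # stands in for WidgetPanel
        v.SplitVertically(image_panel, choice_panel, sashPosition=480)
        top.SplitHorizontally(v, widget_panel, sashPosition=330)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(top, 1, wx.EXPAND)
        self.SetSizer(sizer)

if __name__ == "__main__":
    app = wx.App()
    MiniFrame().Show()
    app.MainLoop()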
Example #7
    def browseDir(self, event):
        """
        Show the DirDialog and ask the user to change the directory where machine labels are stored
        """
        self.statusbar.SetStatusText(
            "Looking for a folder to start labeling...")
        cwd = os.path.join(os.getcwd(), "labeled-data")
        dlg = wx.DirDialog(
            self,
            "Choose the directory where your extracted frames are saved:",
            cwd,
            style=wx.DD_DEFAULT_STYLE,
        )
        if dlg.ShowModal() == wx.ID_OK:
            self.dir = dlg.GetPath()
            self.load.Enable(False)
            self.next.Enable(True)
            self.save.Enable(True)
        else:
            dlg.Destroy()
            self.Close(True)
            return
        dlg.Destroy()

        # Enabling the zoom, pan and home buttons
        self.zoom.Enable(True)
        self.home.Enable(True)
        self.pan.Enable(True)
        self.lock.Enable(True)

        # Reading config file and its variables
        self.cfg = auxiliaryfunctions.read_config(self.config_file)
        self.scorer = self.cfg["scorer"]
        (
            individuals,
            uniquebodyparts,
            multianimalbodyparts,
        ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)

        self.multibodyparts = multianimalbodyparts
        self.uniquebodyparts = uniquebodyparts
        self.individual_names = individuals

        self.videos = self.cfg["video_sets"].keys()
        self.markerSize = self.cfg["dotsize"]
        self.edgewidth = self.markerSize // 3
        self.alpha = self.cfg["alphavalue"]
        self.colormap = plt.get_cmap(self.cfg["colormap"])
        self.colormap = self.colormap.reversed()
        self.idmap = plt.cm.get_cmap("Set1", len(individuals))
        self.project_path = self.cfg["project_path"]

        if self.uniquebodyparts == []:
            self.are_unique_bodyparts_present = False

        self.buttonCounter = {i: [] for i in self.individual_names}
        self.index = np.sort([
            fn for fn in glob.glob(os.path.join(self.dir, "*.png"))
            if ("labeled.png" not in fn)
        ])
        self.statusbar.SetStatusText("Working on folder: {}".format(
            os.path.split(str(self.dir))[-1]))
        self.relativeimagenames = [
            "labeled" + n.split("labeled")[1] for n in self.index
        ]  # [n.split(self.project_path+'/')[1] for n in self.index]

        # Reading the existing dataset,if already present
        try:
            self.dataFrame = pd.read_hdf(
                os.path.join(self.dir, "CollectedData_" + self.scorer + ".h5"),
                "df_with_missing",
            )
            self.dataFrame.sort_index(inplace=True)
            self.prev.Enable(True)
            # Finds the first empty row in the dataframe and sets the iteration to that index
            for idx, j in enumerate(self.dataFrame.index):
                values = self.dataFrame.loc[j, :].values
                if np.isnan(values).all():
                    self.iter = idx
                    break
                else:
                    self.iter = idx
        except Exception:
            # No existing dataset (or it could not be read): create an empty data frame
            self.dataFrame = MainFrame.create_dataframe(
                self,
                self.dataFrame,
                self.relativeimagenames,
                self.individual_names,
                self.uniquebodyparts,
                self.multibodyparts,
            )
            self.iter = 0

        # Reading the image name
        self.img = self.index[self.iter]
        img_name = Path(self.index[self.iter]).name

        # Checking for new frames and adding them to the existing dataframe
        old_imgs = np.sort(list(self.dataFrame.index))
        self.newimages = list(set(self.relativeimagenames) - set(old_imgs))
        if self.newimages:
            print("Found new frames..")
            # Create an empty dataframe with all the new images and then merge this to the existing dataframe.
            self.df = MainFrame.create_dataframe(
                self,
                None,
                self.newimages,
                self.individual_names,
                self.uniquebodyparts,
                self.multibodyparts,
            )
            self.dataFrame = pd.concat([self.dataFrame, self.df], axis=0)
            # Sort it by the index values
            self.dataFrame.sort_index(inplace=True)

        # checks for unique bodyparts
        if len(self.multibodyparts) != len(set(self.multibodyparts)):
            print(
                "Error - bodyparts must have unique labels! Please choose unique bodyparts in config.yaml file and try again. Quitting for now!"
            )
            self.Close(True)

        # Extracting the list of old labels
        first_individual = next(
            i for i in self.individual_names if i != "single")
        oldMultiBodyParts = self.dataFrame.iloc[
            :, self.dataFrame.columns.get_level_values(1) == first_individual
        ].columns.get_level_values(2)
        _, idx = np.unique(oldMultiBodyParts, return_index=True)
        oldMultiBodyparts2plot = list(oldMultiBodyParts[np.sort(idx)])
        # Extracting the list of old unique labels
        if self.are_unique_bodyparts_present:
            oldUniqueBodyParts = self.dataFrame.iloc[
                :, self.dataFrame.columns.get_level_values(1) == "single"
            ].columns.get_level_values(2)
            _, idx_unique = np.unique(oldUniqueBodyParts, return_index=True)
            oldUniqueBodyparts2plot = list(
                oldUniqueBodyParts[np.sort(idx_unique)])
            self.new_UniqueBodyparts = [
                x for x in self.uniquebodyparts
                if x not in oldUniqueBodyparts2plot
            ]
        else:
            self.new_UniqueBodyparts = []
        # Checking if new labels are added or not
        self.new_Multibodyparts = [
            x for x in self.multibodyparts if x not in oldMultiBodyparts2plot
        ]

        # Checking if user added a new label
        if not self.new_Multibodyparts and not self.new_UniqueBodyparts:  # i.e. no new labels
            (
                self.figure,
                self.axes,
                self.canvas,
                self.toolbar,
                self.image_axis,
            ) = self.image_panel.drawplot(
                self.img,
                img_name,
                self.iter,
                self.index,
                self.multibodyparts,
                self.colormap,
                keep_view=self.view_locked,
            )
        else:
            # Found new labels in either multiple bodyparts or unique bodyparts
            dlg = wx.MessageDialog(
                None,
                "New label found in the config file. Do you want to see all the other labels?",
                "New label found",
                wx.YES_NO | wx.ICON_WARNING,
            )
            result = dlg.ShowModal()
            if result == wx.ID_NO:
                if self.new_Multibodyparts:
                    self.multibodyparts = self.new_Multibodyparts
                if self.new_UniqueBodyparts:
                    self.uniquebodyparts = self.new_UniqueBodyparts

            self.dataFrame = MainFrame.create_dataframe(
                self,
                self.dataFrame,
                self.relativeimagenames,
                self.individual_names,
                self.new_UniqueBodyparts,
                self.new_Multibodyparts,
            )
            (
                self.figure,
                self.axes,
                self.canvas,
                self.toolbar,
                self.image_axis,
            ) = self.image_panel.drawplot(
                self.img,
                img_name,
                self.iter,
                self.index,
                self.multibodyparts,
                self.colormap,
                keep_view=self.view_locked,
            )

        self.axes.callbacks.connect("xlim_changed", self.onZoom)
        self.axes.callbacks.connect("ylim_changed", self.onZoom)

        if self.individual_names[0] == "single":
            (
                self.choiceBox,
                self.individualrdb,
                self.rdb,
                self.change_marker_size,
                self.checkBox,
            ) = self.choice_panel.addRadioButtons(self.uniquebodyparts,
                                                  self.individual_names,
                                                  self.file, self.markerSize)
            self.image_panel.addcolorbar(
                self.img,
                self.image_axis,
                self.iter,
                self.uniquebodyparts,
                self.colormap,
            )
        else:
            (
                self.choiceBox,
                self.individualrdb,
                self.rdb,
                self.change_marker_size,
                self.checkBox,
            ) = self.choice_panel.addRadioButtons(self.multibodyparts,
                                                  self.individual_names,
                                                  self.file, self.markerSize)
            self.image_panel.addcolorbar(self.img, self.image_axis, self.iter,
                                         self.multibodyparts, self.colormap)
        self.individualrdb.Bind(wx.EVT_RADIOBOX, self.select_individual)
        # check if "single" is selected when radio buttons are changed
        if self.rdb.GetStringSelection() == "single":
            self.norm, self.colorIndex = self.image_panel.getColorIndices(
                self.img, self.uniquebodyparts)
        else:
            self.norm, self.colorIndex = self.image_panel.getColorIndices(
                self.img, self.multibodyparts)
        self.buttonCounter = MainFrame.plot(self, self.img)
        self.cidClick = self.canvas.mpl_connect("button_press_event",
                                                self.onClick)
        self.cidClick = self.canvas.mpl_connect("button_press_event",
                                                self.onClick)

        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activateSlider)
        self.change_marker_size.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

    def __init__(self, parent, config, video, shuffle, Dataframe, savelabeled,
                 multianimal):
        # Setting the GUI size and panel design
        displays = (
            wx.Display(i) for i in range(wx.Display.GetCount())
        )  # Gets all available displays
        screenSizes = [
            display.GetGeometry().GetSize() for display in displays
        ]  # Gets the size of each display
        index = 0  # For display 1.
        screenWidth = screenSizes[index][0]
        screenHeight = screenSizes[index][1]
        self.gui_size = (screenWidth * 0.7, screenHeight * 0.85)

        wx.Frame.__init__(
            self,
            parent,
            id=wx.ID_ANY,
            title="DeepLabCut2.0 - Manual Outlier Frame Extraction",
            size=wx.Size(self.gui_size),
            pos=wx.DefaultPosition,
            style=wx.RESIZE_BORDER | wx.DEFAULT_FRAME_STYLE | wx.TAB_TRAVERSAL,
        )
        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetStatusText("")

        self.SetSizeHints(
            wx.Size(self.gui_size)
        )  #  This sets the minimum size of the GUI. It can scale now!

        ###################################################################################################################################################
        # Splitting the frame into top and bottom panels. The bottom panel contains the widgets; the top panel is for showing images and plotting!

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)

        vSplitter.SplitVertically(self.image_panel,
                                  self.choice_panel,
                                  sashPosition=self.gui_size[0] * 0.8)
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter,
                                      self.widget_panel,
                                      sashPosition=self.gui_size[1] *
                                      0.83)  # 0.9
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)

        self.load_button_sizer = wx.BoxSizer(wx.VERTICAL)
        self.help_button_sizer = wx.BoxSizer(wx.VERTICAL)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        self.help_button_sizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        widgetsizer.Add(self.help_button_sizer, 1, wx.ALL, 0)

        self.grab = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="Grab Frames")
        widgetsizer.Add(self.grab, 1, wx.ALL, 15)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.grab.Enable(True)

        widgetsizer.AddStretchSpacer(5)
        self.slider = wx.Slider(
            self.widget_panel,
            id=wx.ID_ANY,
            value=0,
            minValue=0,
            maxValue=1,
            size=(200, -1),
            style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS,
        )
        widgetsizer.Add(self.slider, 1, wx.ALL, 5)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

        widgetsizer.AddStretchSpacer(5)
        self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL)
        self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL)

        self.start_frames_sizer.AddSpacer(15)
        self.startFrame = wx.SpinCtrl(self.widget_panel,
                                      value="0",
                                      size=(100, -1))  # ,style=wx.SP_VERTICAL)
        self.startFrame.Enable(False)
        self.start_frames_sizer.Add(self.startFrame, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        start_text = wx.StaticText(self.widget_panel,
                                   label="Start Frame Index")
        self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                    15)
        self.checkBox = wx.CheckBox(self.widget_panel,
                                    id=wx.ID_ANY,
                                    label="Range of frames")
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range)
        self.start_frames_sizer.Add(self.checkBox, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.end_frames_sizer.AddSpacer(15)
        self.endFrame = wx.SpinCtrl(self.widget_panel,
                                    value="1",
                                    size=(160, -1))  # , min=1, max=120)
        self.endFrame.Enable(False)
        self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                  15)
        end_text = wx.StaticText(self.widget_panel, label="Number of Frames")
        self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame = wx.Button(self.widget_panel,
                                     id=wx.ID_ANY,
                                     label="Update")
        self.end_frames_sizer.Add(self.updateFrame, 1,
                                  wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider)
        self.updateFrame.Enable(False)

        widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(5)
        widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(15)

        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)
        self.quit.Enable(True)

        self.widget_panel.SetSizer(widgetsizer)
        self.widget_panel.SetSizerAndFit(widgetsizer)

        # Variables initialization
        self.numberFrames = 0
        self.currFrame = 0
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.drs = []
        self.extract_range_frame = False
        self.firstFrame = 0
        self.Colorscheme = []
        # self.cropping = False

        # Read config file
        self.cfg = auxiliaryfunctions.read_config(config)
        self.Task = self.cfg["Task"]
        self.start = self.cfg["start"]
        self.stop = self.cfg["stop"]
        self.date = self.cfg["date"]
        self.trainFraction = self.cfg["TrainingFraction"]
        self.trainFraction = self.trainFraction[0]
        self.videos = self.cfg["video_sets"].keys()
        self.bodyparts = self.cfg["bodyparts"]
        self.colormap = plt.get_cmap(self.cfg["colormap"])
        self.colormap = self.colormap.reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.iterationindex = self.cfg["iteration"]
        self.cropping = self.cfg["cropping"]
        self.video_names = [Path(i).stem for i in self.videos]
        self.config_path = Path(config)
        self.video_source = Path(video).resolve()
        self.shuffle = shuffle
        self.Dataframe = Dataframe
        self.savelabeled = savelabeled
        self.multianimal = multianimal
        if self.multianimal:
            from deeplabcut.utils import auxfun_multianimal

            (
                self.individual_names,
                self.uniquebodyparts,
                self.multianimalbodyparts,
            ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
            self.choiceBox, self.visualization_rdb = self.choice_panel.addRadioButtons(
            )
            self.Colorscheme = visualization.get_cmap(
                len(self.individual_names), self.cfg["colormap"])
            self.visualization_rdb.Bind(wx.EVT_RADIOBOX, self.clear_plot)
        # Read the video file
        self.vid = cv2.VideoCapture(str(self.video_source))
        self.videoPath = os.path.dirname(self.video_source)
        self.filename = Path(self.video_source).name
        self.numberFrames = int(self.vid.get(cv2.CAP_PROP_FRAME_COUNT))
        self.strwidth = int(np.ceil(np.log10(self.numberFrames)))
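        # strwidth: number of digits needed to zero-pad frame indices in
        # extracted-image filenames (e.g. 1000 frames -> indices 0-999 -> width 3)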
        # Set the values of slider and range of frames
        self.startFrame.SetMax(self.numberFrames - 1)
        self.slider.SetMax(self.numberFrames - 1)
        self.endFrame.SetMax(self.numberFrames - 1)
        self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider)  # wx.EVT_SPIN
        # Set the status bar
        self.statusbar.SetStatusText("Working on video: {}".format(
            os.path.split(str(self.video_source))[-1]))
        # Adding the video file to the config file.
        if str(self.video_source.stem) not in self.video_names:
            add.add_new_videos(self.config_path, [self.video_source])

        self.update()
        self.plot_labels()
        self.widget_panel.Layout()
Exemple #9
0
def _benchmark_paf_graphs(
    config,
    inference_cfg,
    data,
    paf_inds,
    greedy=False,
    add_discarded=True,
    identity_only=False,
    calibration_file="",
    oks_sigma=0.1,
):
    n_multi = len(auxfun_multianimal.extractindividualsandbodyparts(config)[2])
    data_ = {"metadata": data.pop("metadata")}
    for k, v in data.items():
        data_[k] = v["prediction"]
    ass = Assembler(
        data_,
        max_n_individuals=inference_cfg["topktoretain"],
        n_multibodyparts=n_multi,
        greedy=greedy,
        pcutoff=inference_cfg.get("pcutoff", 0.1),
        min_affinity=inference_cfg.get("pafthreshold", 0.1),
        add_discarded=add_discarded,
        identity_only=identity_only,
    )
    if calibration_file:
        ass.calibrate(calibration_file)

    params = ass.metadata
    image_paths = params["imnames"]
    bodyparts = params["joint_names"]
    idx = (
        data[image_paths[0]]["groundtruth"][2]
        .unstack("coords")
        .reindex(bodyparts, level="bodyparts")
        .index
    )
    individuals = idx.get_level_values("individuals").unique()
    n_individuals = len(individuals)
    map_ = dict(zip(individuals, range(n_individuals)))

    # Form ground truth beforehand
    ground_truth = []
    for i, imname in enumerate(image_paths):
        temp = data[imname]["groundtruth"][2]
        ground_truth.append(temp.to_numpy().reshape((-1, 2)))
    ground_truth = np.stack(ground_truth)
    temp = np.ones((*ground_truth.shape[:2], 3))
    temp[..., :2] = ground_truth
    temp = temp.reshape((temp.shape[0], n_individuals, -1, 3))
    ass_true_dict = _parse_ground_truth_data(temp)
    ids = np.vectorize(map_.get)(
        idx.get_level_values("individuals").to_numpy())
    ground_truth = np.insert(ground_truth, 2, ids, axis=2)
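    # ground_truth now stores one (x, y, animal_id) row per annotated keypoint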

    # Assemble animals on the full set of detections
    paf_inds = sorted(paf_inds, key=len)
    paf_graph = ass.graph
    n_graphs = len(paf_inds)
    all_scores = []
    all_metrics = []
    for j, paf in enumerate(paf_inds, start=1):
        print(f"Graph {j}|{n_graphs}")
        ass.paf_inds = paf
        ass.assemble()
        oks = evaluate_assembly(ass.assemblies, ass_true_dict, oks_sigma)
        all_metrics.append(oks)
        scores = np.full((len(image_paths), 2), np.nan)
        for i, imname in enumerate(tqdm(image_paths)):
            gt = ground_truth[i]
            gt = gt[~np.isnan(gt).any(axis=1)]
            if len(np.unique(gt[:, 2])) < 2:  # Only consider frames with 2+ animals
                continue

            # Count the number of unassembled bodyparts
            n_dets = len(gt)
            animals = ass.assemblies.get(i)
            if animals is None:
                if n_dets:
                    scores[i, 0] = 1
            else:
                animals = [
                    np.c_[animal.data,
                          np.ones(animal.data.shape[0]) * n]
                    for n, animal in enumerate(animals)
                ]
                hyp = np.concatenate(animals)
                hyp = hyp[~np.isnan(hyp).any(axis=1)]
                scores[i, 0] = (n_dets - hyp.shape[0]) / n_dets
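                # Identity purity: match each GT keypoint to its nearest assembled
                # detection, then take, for every predicted assembly, the count of
                # its dominant GT identity; purity = sum of those counts over all
                # matched detections (column maxima of the contingency matrix).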
                neighbors = _find_closest_neighbors(gt[:, :2], hyp[:, :2])
                valid = neighbors != -1
                id_gt = gt[valid, 2]
                id_hyp = hyp[neighbors[valid], -1]
                mat = contingency_matrix(id_gt, id_hyp)
                purity = mat.max(axis=0).sum() / mat.sum()
                scores[i, 1] = purity
        all_scores.append((scores, paf))

    dfs = []
    for score, inds in all_scores:
        df = pd.DataFrame(score, columns=["miss", "purity"])
        df["ngraph"] = len(inds)
        dfs.append(df)
    big_df = pd.concat(dfs)
    group = big_df.groupby("ngraph")
    return (all_scores, group.agg(["mean", "std"]).T, all_metrics)
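# A minimal usage sketch (hypothetical paths and PAF index subsets; `data` is
# the predictions dictionary loaded from a *_full.pickle file produced during
# evaluation, and `inference_cfg` the parsed inference_cfg.yaml):
# >>> cfg = auxiliaryfunctions.read_config('/analysis/project/demo/config.yaml')
# >>> paf_inds = [list(range(5)), list(range(n_edges))]  # e.g. a pruned vs. the full graph
# >>> _, summary, metrics = _benchmark_paf_graphs(cfg, inference_cfg, data, paf_inds)
# >>> print(summary)  # mean/std of 'miss' and 'purity' per graph size
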
    def __init__(self, parent, config):
        super(MainFrame, self).__init__(
            "DeepLabCut - Refinement ToolBox",
            parent,
        )
        self.Bind(wx.EVT_CHAR_HOOK, self.OnKeyPressed)

        ###################################################################################################################################################

        # Spliting the frame into top and bottom panels. Bottom panels contains the widgets. The top panel is for showing images and plotting!

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, config, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)
        vSplitter.SplitVertically(self.image_panel,
                                  self.choice_panel,
                                  sashPosition=self.gui_size[0] * 0.8)
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter,
                                      self.widget_panel,
                                      sashPosition=self.gui_size[1] *
                                      0.83)  # 0.9
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.load = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="Load labels")
        widgetsizer.Add(self.load, 1, wx.ALL, 15)
        self.load.Bind(wx.EVT_BUTTON, self.browseDir)

        self.prev = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="<<Previous")
        widgetsizer.Add(self.prev, 1, wx.ALL, 15)
        self.prev.Bind(wx.EVT_BUTTON, self.prevImage)
        self.prev.Enable(False)

        self.next = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Next>>")
        widgetsizer.Add(self.next, 1, wx.ALL, 15)
        self.next.Bind(wx.EVT_BUTTON, self.nextImage)
        self.next.Enable(False)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        widgetsizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)
        self.help.Enable(True)

        self.zoom = wx.ToggleButton(self.widget_panel, label="Zoom")
        widgetsizer.Add(self.zoom, 1, wx.ALL, 15)
        self.zoom.Bind(wx.EVT_TOGGLEBUTTON, self.zoomButton)
        self.widget_panel.SetSizer(widgetsizer)
        self.zoom.Enable(False)

        self.home = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Home")
        widgetsizer.Add(self.home, 1, wx.ALL, 15)
        self.home.Bind(wx.EVT_BUTTON, self.homeButton)
        self.widget_panel.SetSizer(widgetsizer)
        self.home.Enable(False)

        self.pan = wx.ToggleButton(self.widget_panel,
                                   id=wx.ID_ANY,
                                   label="Pan")
        widgetsizer.Add(self.pan, 1, wx.ALL, 15)
        self.pan.Bind(wx.EVT_TOGGLEBUTTON, self.panButton)
        self.widget_panel.SetSizer(widgetsizer)
        self.pan.Enable(False)

        self.lock = wx.CheckBox(self.widget_panel,
                                id=wx.ID_ANY,
                                label="Lock View")
        widgetsizer.Add(self.lock, 1, wx.ALL, 15)
        self.lock.Bind(wx.EVT_CHECKBOX, self.lockChecked)
        self.widget_panel.SetSizer(widgetsizer)
        self.lock.Enable(False)

        self.save = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Save")
        widgetsizer.Add(self.save, 1, wx.ALL, 15)
        self.save.Bind(wx.EVT_BUTTON, self.saveDataSet)
        self.save.Enable(False)

        widgetsizer.AddStretchSpacer(15)
        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)

        self.widget_panel.SetSizer(widgetsizer)
        self.widget_panel.SetSizerAndFit(widgetsizer)
        self.widget_panel.Layout()

        ###############################################################################################################################
        # Variable initialization
        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []
        self.threshold = []
        self.file = 0
        self.updatedCoords = []
        self.drs = []
        self.cfg = auxiliaryfunctions.read_config(config)
        self.humanscorer = self.cfg["scorer"]
        self.move2corner = self.cfg["move2corner"]
        self.center = self.cfg["corner2move2"]
        self.colormap = plt.get_cmap(self.cfg["colormap"])
        self.colormap = self.colormap.reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.iterationindex = self.cfg["iteration"]
        self.project_path = self.cfg["project_path"]
        self.bodyparts = self.cfg["bodyparts"]
        self.threshold = 0.1
        self.img_size = (10, 6)  # (imgW, imgH)  # width, height in inches.
        self.preview = False
        self.view_locked = False
        # Workaround for MAC - xlim and ylim changed events seem to be triggered too often so need to make sure that the
        # xlim and ylim have actually changed before turning zoom off
        self.prezoom_xlim = []
        self.prezoom_ylim = []
        from deeplabcut.utils import auxfun_multianimal

        (
            self.individual_names,
            self.uniquebodyparts,
            self.multianimalbodyparts,
        ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
        self.Colorscheme = visualization.get_cmap(len(self.individual_names),
                                                  self.cfg["colormap"])
def evaluate_multianimal_full(
    config,
    Shuffles=[1],
    trainingsetindex=0,
    plotting=None,
    show_errors=True,
    comparisonbodyparts="all",
    gputouse=None,
    modelprefix="",
    c_engine=False,
):
    """
    WIP: evaluates a multi-animal project network on the training dataset.
    """

    import os

    from deeplabcut.pose_estimation_tensorflow.nnet import predict
    from deeplabcut.pose_estimation_tensorflow.nnet import predict_multianimal as predictma
    from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal
    from deeplabcut.pose_estimation_tensorflow.dataset.pose_dataset import data_to_input

    import tensorflow as tf

    if "TF_CUDNN_USE_AUTOTUNE" in os.environ:
        del os.environ[
            "TF_CUDNN_USE_AUTOTUNE"]  # was potentially set during training

    tf.reset_default_graph()
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  #
    if gputouse is not None:  # gpu selectinon
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gputouse)

    start_path = os.getcwd()

    ##################################################
    # Load data...
    ##################################################
    cfg = auxiliaryfunctions.read_config(config)
    if trainingsetindex == "all":
        TrainingFractions = cfg["TrainingFraction"]
    else:
        TrainingFractions = [cfg["TrainingFraction"][trainingsetindex]]

    # Loading human-annotated data
    trainingsetfolder = auxiliaryfunctions.GetTrainingSetFolder(cfg)
    Data = pd.read_hdf(
        os.path.join(
            cfg["project_path"],
            str(trainingsetfolder),
            "CollectedData_" + cfg["scorer"] + ".h5",
        ),
        "df_with_missing",
    )
    # Get list of body parts to evaluate network for
    comparisonbodyparts = auxiliaryfunctions.IntersectionofBodyPartsandOnesGivenbyUser(
        cfg, comparisonbodyparts)
    colors = visualization.get_cmap(len(comparisonbodyparts),
                                    name=cfg["colormap"])
    # Make folder for evaluation
    auxiliaryfunctions.attempttomakefolder(
        str(cfg["project_path"] + "/evaluation-results/"))
    for shuffle in Shuffles:
        for trainFraction in TrainingFractions:
            ##################################################
            # Load and setup CNN part detector
            ##################################################
            datafn, metadatafn = auxiliaryfunctions.GetDataandMetaDataFilenames(
                trainingsetfolder, trainFraction, shuffle, cfg)
            modelfolder = os.path.join(
                cfg["project_path"],
                str(
                    auxiliaryfunctions.GetModelFolder(
                        trainFraction, shuffle, cfg, modelprefix=modelprefix)),
            )
            path_test_config = Path(modelfolder) / "test" / "pose_cfg.yaml"

            # Load meta data
            (
                data,
                trainIndices,
                testIndices,
                trainFraction,
            ) = auxiliaryfunctions.LoadMetadata(
                os.path.join(cfg["project_path"], metadatafn))

            try:
                dlc_cfg = load_config(str(path_test_config))
            except FileNotFoundError:
                raise FileNotFoundError(
                    "It seems the model for shuffle %s and trainFraction %s does not exist."
                    % (shuffle, trainFraction))

            # TODO: IMPLEMENT for different batch sizes?
            dlc_cfg["batch_size"] = 1  # due to differently sized images!!!

            # Create folder structure to store results.
            evaluationfolder = os.path.join(
                cfg["project_path"],
                str(
                    auxiliaryfunctions.GetEvaluationFolder(
                        trainFraction, shuffle, cfg, modelprefix=modelprefix)),
            )
            auxiliaryfunctions.attempttomakefolder(evaluationfolder,
                                                   recursive=True)
            # path_train_config = modelfolder / 'train' / 'pose_cfg.yaml'

            # Check which snapshots are available and sort them by # iterations
            Snapshots = np.array([
                fn.split(".")[0]
                for fn in os.listdir(os.path.join(str(modelfolder), "train"))
                if "index" in fn
            ])
            if len(Snapshots) == 0:
                print(
                    "Snapshots not found! It seems the dataset for shuffle %s and trainFraction %s is not trained.\nPlease train it before evaluating.\nUse the function 'train_network' to do so."
                    % (shuffle, trainFraction))
            else:
                increasing_indices = np.argsort(
                    [int(m.split("-")[1]) for m in Snapshots])
                Snapshots = Snapshots[increasing_indices]

                if cfg["snapshotindex"] == -1:
                    snapindices = [-1]
                elif cfg["snapshotindex"] == "all":
                    snapindices = range(len(Snapshots))
                elif cfg["snapshotindex"] < len(Snapshots):
                    snapindices = [cfg["snapshotindex"]]
                else:
                    raise ValueError(
                        "Invalid choice, only -1 (last), any integer up to last, or all (as string)!"
                    )

                (
                    individuals,
                    uniquebodyparts,
                    multianimalbodyparts,
                ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)

                final_result = []
                ##################################################
                # Compute predictions over images
                ##################################################
                for snapindex in snapindices:
                    dlc_cfg["init_weights"] = os.path.join(
                        str(modelfolder), "train", Snapshots[snapindex]
                    )  # setting weights to corresponding snapshot.
                    trainingsiterations = (
                        dlc_cfg["init_weights"].split(os.sep)[-1]
                    ).split("-")[-1]  # read how many training iterations that corresponds to

                    # name for deeplabcut net (based on its parameters)
                    DLCscorer, DLCscorerlegacy = auxiliaryfunctions.GetScorerName(
                        cfg,
                        shuffle,
                        trainFraction,
                        trainingsiterations,
                        modelprefix=modelprefix,
                    )
                    print(
                        "Running ",
                        DLCscorer,
                        " with # of trainingiterations:",
                        trainingsiterations,
                    )
                    (
                        notanalyzed,
                        resultsfilename,
                        DLCscorer,
                    ) = auxiliaryfunctions.CheckifNotEvaluated(
                        str(evaluationfolder),
                        DLCscorer,
                        DLCscorerlegacy,
                        Snapshots[snapindex],
                    )

                    if os.path.isfile(
                            resultsfilename.split(".h5")[0] + "_full.pickle"):
                        print("Model already evaluated.", resultsfilename)
                    else:
                        if plotting:
                            foldername = os.path.join(
                                str(evaluationfolder),
                                "LabeledImages_" + DLCscorer + "_" +
                                Snapshots[snapindex],
                            )
                            auxiliaryfunctions.attempttomakefolder(foldername)

                        # print(dlc_cfg)
                        # Specifying state of model (snapshot / training state)
                        sess, inputs, outputs = predict.setup_pose_prediction(
                            dlc_cfg)

                        PredicteData = {}
                        print("Analyzing data...")
                        for imageindex, imagename in tqdm(enumerate(
                                Data.index)):
                            image_path = os.path.join(cfg["project_path"],
                                                      imagename)
                            image = io.imread(image_path)
                            frame = img_as_ubyte(skimage.color.gray2rgb(image))

                            GT = Data.iloc[imageindex]

                            # Storing GT data as dictionary, so it can be used for calculating connection costs
                            groundtruthcoordinates = []
                            groundtruthidentity = []
                            for bptindex, bpt in enumerate(
                                    dlc_cfg["all_joints_names"]):
                                coords = np.zeros([len(individuals), 2]) * np.nan
                                identity = []
                                for prfxindex, prefix in enumerate(individuals):
                                    if bpt in uniquebodyparts and prefix == "single":
                                        coords[prfxindex, :] = np.array([
                                            GT[cfg["scorer"]][prefix][bpt]["x"],
                                            GT[cfg["scorer"]][prefix][bpt]["y"],
                                        ])
                                        identity.append(prefix)
                                    elif bpt in multianimalbodyparts and prefix != "single":
                                        coords[prfxindex, :] = np.array([
                                            GT[cfg["scorer"]][prefix][bpt]["x"],
                                            GT[cfg["scorer"]][prefix][bpt]["y"],
                                        ])
                                        identity.append(prefix)
                                    else:
                                        identity.append("nix")

                                groundtruthcoordinates.append(
                                    coords[np.isfinite(coords[:, 0]), :])
                                groundtruthidentity.append(
                                    np.array(identity)[np.isfinite(coords[:, 0])]
                                )

                            PredicteData[imagename] = {}
                            PredicteData[imagename]["index"] = imageindex

                            pred = predictma.get_detectionswithcostsandGT(
                                frame,
                                groundtruthcoordinates,
                                dlc_cfg,
                                sess,
                                inputs,
                                outputs,
                                outall=False,
                                nms_radius=dlc_cfg.nmsradius,
                                det_min_score=dlc_cfg.minconfidence,
                                c_engine=c_engine,
                            )
                            PredicteData[imagename]["prediction"] = pred
                            PredicteData[imagename]["groundtruth"] = [
                                groundtruthidentity,
                                groundtruthcoordinates,
                                GT,
                            ]

                            if plotting:
                                coords_pred = pred["coordinates"][0]
                                probs_pred = pred["confidence"]
                                fig = visualization.make_multianimal_labeled_image(
                                    frame,
                                    groundtruthcoordinates,
                                    coords_pred,
                                    probs_pred,
                                    colors,
                                    cfg["dotsize"],
                                    cfg["alphavalue"],
                                    cfg["pcutoff"],
                                )

                                visualization.save_labeled_frame(
                                    fig,
                                    image_path,
                                    foldername,
                                    imageindex in trainIndices,
                                )

                        sess.close()  # closes the current tf session
                        PredicteData["metadata"] = {
                            "nms radius":
                            dlc_cfg.nmsradius,
                            "minimal confidence":
                            dlc_cfg.minconfidence,
                            "PAFgraph":
                            dlc_cfg.partaffinityfield_graph,
                            "all_joints":
                            [[i] for i in range(len(dlc_cfg.all_joints))],
                            "all_joints_names": [
                                dlc_cfg.all_joints_names[i]
                                for i in range(len(dlc_cfg.all_joints))
                            ],
                            "stride":
                            dlc_cfg.get("stride", 8),
                        }
                        print(
                            "Done and results stored for snapshot: ",
                            Snapshots[snapindex],
                        )

                        dictionary = {
                            "Scorer": DLCscorer,
                            "DLC-model-config file": dlc_cfg,
                            "trainIndices": trainIndices,
                            "testIndices": testIndices,
                            "trainFraction": trainFraction,
                        }
                        metadata = {"data": dictionary}
                        auxfun_multianimal.SaveFullMultiAnimalData(
                            PredicteData, metadata, resultsfilename)

                        tf.reset_default_graph()

    # returning to initial folder
    os.chdir(str(start_path))
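# Hedged usage sketch (hypothetical project path; evaluates shuffle 1 on the
# first training fraction and stores *_full.pickle results per snapshot):
# >>> evaluate_multianimal_full('/analysis/project/reaching-task/config.yaml',
# ...                           Shuffles=[1], plotting=True)
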
def create_multianimaltraining_dataset(
    config,
    num_shuffles=1,
    Shuffles=None,
    windows2linux=False,
    net_type=None,
    numdigits=2,
):
    """
    Creates a training dataset for multi-animal datasets. Labels from all the extracted frames are merged into a single .h5 file.\n
    Only the videos included in the config file are used to create this dataset.\n
    [OPTIONAL] Use the function 'add_new_video' at any stage of the project to add more videos to the project.

    Important differences to the standard case:
     - stores coordinates with numdigits as many digits
     - creates
    Parameters
    ----------
    config : string
        Full path of the config.yaml file as a string.

    num_shuffles : int, optional
        Number of shuffles of training dataset to create, i.e. [1,2,3] for num_shuffles=3. Default is set to 1.

    Shuffles: list of shuffles.
        Alternatively the user can also give a list of shuffles (integers!).

    windows2linux: bool.
        The annotation files contain paths formatted according to your operating system. If you label on Windows
        but train & evaluate on a unix system (e.g. Ubuntu, Colab, Mac) set this variable to True to convert the paths.

    net_type: string
        Type of networks. Currently resnet_50, resnet_101, and resnet_152 are supported (not the MobileNets!)

    numdigits: int, optional
        Number of decimal digits used when storing label coordinates. Default is 2.

    Example
    --------
    >>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1)

    Windows:
    >>> deeplabcut.create_multianimaltraining_dataset(r'C:\\Users\\Ulf\\looming-task\\config.yaml',Shuffles=[3,17,5])
    --------
    """
    from skimage import io

    # Loading metadata from config file:
    cfg = auxiliaryfunctions.read_config(config)
    scorer = cfg["scorer"]
    project_path = cfg["project_path"]
    # Create path for training sets & store data there
    trainingsetfolder = auxiliaryfunctions.GetTrainingSetFolder(
        cfg)  # Path concatenation, OS platform independent
    auxiliaryfunctions.attempttomakefolder(
        Path(os.path.join(project_path, str(trainingsetfolder))), recursive=True
    )

    Data = trainingsetmanipulation.merge_annotateddatasets(
        cfg, Path(os.path.join(project_path, trainingsetfolder)),
        windows2linux)
    if Data is None:
        return
    Data = Data[scorer]  # extract labeled data

    # actualbpts=set(Data.columns.get_level_values(0))

    def strip_cropped_image_name(path):
        # utility function to split different crops from same image into either train or test!
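        # e.g. a hypothetical 'img0017c2.png' maps to 'img0017' (assumes the
        # crop suffix starts at the first 'c' in the filename)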
        filename = os.path.split(path)[1]
        return filename.split("c")[0]

    img_names = Data.index.map(strip_cropped_image_name).unique()

    # loading & linking pretrained models
    # CURRENTLY ONLY ResNet supported!
    if net_type is None:  # loading & linking pretrained models
        net_type = cfg.get("default_net_type", "resnet_50")
    else:
        if "resnet" in net_type:  # or 'mobilenet' in net_type:
            pass
        else:
            raise ValueError("Currently only resnet is supported.")

    # multianimal case:
    dataset_type = "multi-animal-imgaug"
    partaffinityfield_graph = auxfun_multianimal.getpafgraph(cfg,
                                                             printnames=False)
    # ATTENTION: order has to be multibodyparts, then uniquebodyparts (for indexing)
    print("Utilizing the following graph:", partaffinityfield_graph)
    num_limbs = len(partaffinityfield_graph)
    partaffinityfield_predict = True

    # Loading the encoder (if necessary downloading from TF)
    dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
    defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
    model_path, num_shuffles = auxfun_models.Check4weights(
        net_type, Path(dlcparent_path), num_shuffles)

    if Shuffles is None:
        Shuffles = range(1, num_shuffles + 1, 1)
    else:
        Shuffles = [i for i in Shuffles if isinstance(i, int)]

    (
        individuals,
        uniquebodyparts,
        multianimalbodyparts,
    ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)

    TrainingFraction = cfg["TrainingFraction"]
    for shuffle in Shuffles:  # Creating shuffles starting from 1
        for trainFraction in TrainingFraction:
            train_inds_temp, test_inds_temp = trainingsetmanipulation.SplitTrials(
                range(len(img_names)), trainFraction)
            # Map back to the original indices.
            temp = [
                name for i, name in enumerate(img_names) if i in test_inds_temp
            ]
            mask = Data.index.str.contains("|".join(temp))
            testIndices = np.flatnonzero(mask)
            trainIndices = np.flatnonzero(~mask)

            ####################################################
            # Generating data structure with labeled information & frame metadata (for deep cut)
            ####################################################

            # Make training file!
            data = []
            print("Creating training data for ", shuffle, trainFraction)
            print("This can take some time...")
            for jj in tqdm(trainIndices):
                jointsannotated = False
                H = {}
                # load image to get dimensions:
                filename = Data.index[jj]
                im = io.imread(os.path.join(cfg["project_path"], filename))
                H["image"] = filename

                try:
                    H["size"] = np.array(
                        [np.shape(im)[2], np.shape(im)[0], np.shape(im)[1]]
                    )
                except IndexError:
                    # Grayscale image (no channel dimension)
                    H["size"] = np.array([1, np.shape(im)[0], np.shape(im)[1]])

                Joints = {}
                for prfxindex, prefix in enumerate(individuals):
                    joints = (
                        np.zeros((len(uniquebodyparts) + len(multianimalbodyparts), 3))
                        * np.nan
                    )
                    if prefix != "single":  # first ones are multianimalparts!
                        indexjoints = 0
                        for bpindex, bodypart in enumerate(
                                multianimalbodyparts):
                            socialbdpt = bodypart  # prefix+bodypart #build names!
                            # if socialbdpt in actualbpts:
                            try:
                                x, y = (
                                    Data[prefix][socialbdpt]["x"][jj],
                                    Data[prefix][socialbdpt]["y"][jj],
                                )
                                joints[indexjoints, 0] = int(bpindex)
                                joints[indexjoints, 1] = round(x, numdigits)
                                joints[indexjoints, 2] = round(y, numdigits)
                                indexjoints += 1
                            except KeyError:
                                # this bodypart is not annotated for this frame
                                pass
                    else:
                        indexjoints = len(multianimalbodyparts)
                        for bpindex, bodypart in enumerate(uniquebodyparts):
                            socialbdpt = bodypart  # prefix+bodypart #build names!
                            # if socialbdpt in actualbpts:
                            try:
                                x, y = (
                                    Data[prefix][socialbdpt]["x"][jj],
                                    Data[prefix][socialbdpt]["y"][jj],
                                )
                                joints[indexjoints, 0] = (
                                    len(multianimalbodyparts) + int(bpindex)
                                )
                                joints[indexjoints, 1] = round(x, numdigits)
                                joints[indexjoints, 2] = round(y, numdigits)
                                indexjoints += 1
                            except KeyError:
                                # this bodypart is not annotated for this frame
                                pass

                    # Drop missing body parts
                    joints = joints[~np.isnan(joints).any(axis=1)]
                    # Drop points lying outside the image
                    inside = np.logical_and.reduce((
                        joints[:, 1] < im.shape[1],
                        joints[:, 1] > 0,
                        joints[:, 2] < im.shape[0],
                        joints[:, 2] > 0,
                    ))
                    joints = joints[inside]

                    if np.size(joints) > 0:  # exclude images without labels
                        jointsannotated = True

                    Joints[prfxindex] = joints  # np.array(joints, dtype=int)

                H["joints"] = Joints
                if jointsannotated:  # exclude images without labels
                    data.append(H)

            if len(trainIndices) > 0:
                (
                    datafilename,
                    metadatafilename,
                ) = auxiliaryfunctions.GetDataandMetaDataFilenames(
                    trainingsetfolder, trainFraction, shuffle, cfg)
                ################################################################################
                # Saving metadata and data file (Pickle file)
                ################################################################################
                auxiliaryfunctions.SaveMetadata(
                    os.path.join(project_path, metadatafilename),
                    data,
                    trainIndices,
                    testIndices,
                    trainFraction,
                )

                datafilename = datafilename.split(".mat")[0] + ".pickle"
                import pickle

                with open(os.path.join(project_path, datafilename), "wb") as f:
                    # Pickle the 'labeled-data' dictionary using the highest protocol available.
                    pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

                ################################################################################
                # Creating file structure for training &
                # Test files as well as pose_yaml files (containing training and testing information)
                #################################################################################

                modelfoldername = auxiliaryfunctions.GetModelFolder(
                    trainFraction, shuffle, cfg)
                auxiliaryfunctions.attempttomakefolder(
                    Path(config).parents[0] / modelfoldername, recursive=True)
                auxiliaryfunctions.attempttomakefolder(
                    str(Path(config).parents[0] / modelfoldername) + "/train")
                auxiliaryfunctions.attempttomakefolder(
                    str(Path(config).parents[0] / modelfoldername) + "/test")

                path_train_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "train",
                        "pose_cfg.yaml",
                    ))
                path_test_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "test",
                        "pose_cfg.yaml",
                    ))
                path_inference_config = str(
                    os.path.join(
                        cfg["project_path"],
                        Path(modelfoldername),
                        "test",
                        "inference_cfg.yaml",
                    ))

                jointnames = [str(bpt) for bpt in multianimalbodyparts]
                jointnames.extend([str(bpt) for bpt in uniquebodyparts])
                items2change = {
                    "dataset": datafilename,
                    "metadataset": metadatafilename,
                    "num_joints": len(multianimalbodyparts) + len(uniquebodyparts),
                    "all_joints": [
                        [i]
                        for i in range(len(multianimalbodyparts) + len(uniquebodyparts))
                    ],
                    "all_joints_names": jointnames,
                    "init_weights": model_path,
                    "project_path": str(cfg["project_path"]),
                    "net_type": net_type,
                    "pairwise_loss_weight": 0.1,
                    "pafwidth": 20,
                    "partaffinityfield_graph": partaffinityfield_graph,
                    "partaffinityfield_predict": partaffinityfield_predict,
                    "weigh_only_present_joints": False,
                    "num_limbs": len(partaffinityfield_graph),
                    "dataset_type": dataset_type,
                    "optimizer": "adam",
                    "batch_size": 8,
                    "multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]],
                    "save_iters": 10000,
                    "display_iters": 500,
                }

                defaultconfigfile = os.path.join(dlcparent_path,
                                                 "pose_cfg.yaml")
                trainingdata = trainingsetmanipulation.MakeTrain_pose_yaml(
                    items2change, path_train_config, defaultconfigfile)
                keys2save = [
                    "dataset",
                    "num_joints",
                    "all_joints",
                    "all_joints_names",
                    "net_type",
                    "init_weights",
                    "global_scale",
                    "location_refinement",
                    "locref_stdev",
                    "dataset_type",
                    "partaffinityfield_predict",
                    "pairwise_predict",
                    "partaffinityfield_graph",
                    "num_limbs",
                    "dataset_type",
                ]

                trainingsetmanipulation.MakeTest_pose_yaml(
                    trainingdata,
                    keys2save,
                    path_test_config,
                    nmsradius=5.0,
                    minconfidence=0.01,
                )  # setting important def. values for inference

                # Setting inference cfg file:
                defaultinference_configfile = os.path.join(
                    dlcparent_path, "inference_cfg.yaml")
                items2change = {
                    "minimalnumberofconnections": int(
                        len(cfg["multianimalbodyparts"]) / 2
                    ),
                    "topktoretain": len(cfg["individuals"])
                    + 1 * (len(cfg["uniquebodyparts"]) > 0),
                }
                # TODO:   "distnormalization":  could be calculated here based on data and set
                # >> now we calculate this during evaluation (which is a good spot...)
                trainingsetmanipulation.MakeInference_yaml(
                    items2change, path_inference_config,
                    defaultinference_configfile)

                print(
                    "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!"
                )

    def __init__(self, parent, config):
        # Setting the GUI size and panel design
        displays = (wx.Display(i) for i in range(wx.Display.GetCount())) # Gets the number of displays
        screenSizes = [display.GetGeometry().GetSize() for display in displays] # Gets the size of each display
        index = 0 # For display 1.
        screenWidth = screenSizes[index][0]
        screenHeight = screenSizes[index][1]
        self.gui_size = (screenWidth*0.7,screenHeight*0.85)

        wx.Frame.__init__ ( self, parent, id = wx.ID_ANY, title = 'DeepLabCut2.0 - Refinement ToolBox',
                            size = wx.Size(self.gui_size), pos = wx.DefaultPosition, style = wx.RESIZE_BORDER|wx.DEFAULT_FRAME_STYLE|wx.TAB_TRAVERSAL )
        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetStatusText("")
        self.Bind(wx.EVT_CHAR_HOOK, self.OnKeyPressed)

        self.SetSizeHints(wx.Size(self.gui_size)) #  This sets the minimum size of the GUI. It can scale now!
###################################################################################################################################################

        # Splitting the frame into top and bottom panels. The bottom panel contains the widgets; the top panel is for showing images and plotting!

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, config,self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)
        vSplitter.SplitVertically(self.image_panel,self.choice_panel, sashPosition=self.gui_size[0]*0.8)
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter, self.widget_panel,sashPosition=self.gui_size[1]*0.83)#0.9
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.load = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Load labels")
        widgetsizer.Add(self.load, 1, wx.ALL, 15)
        self.load.Bind(wx.EVT_BUTTON, self.browseDir)

        self.prev = wx.Button(self.widget_panel, id=wx.ID_ANY, label="<<Previous")
        widgetsizer.Add(self.prev, 1, wx.ALL, 15)
        self.prev.Bind(wx.EVT_BUTTON, self.prevImage)
        self.prev.Enable(False)

        self.next = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Next>>")
        widgetsizer.Add(self.next, 1, wx.ALL, 15)
        self.next.Bind(wx.EVT_BUTTON, self.nextImage)
        self.next.Enable(False)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        widgetsizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)
        self.help.Enable(True)

        self.zoom = wx.ToggleButton(self.widget_panel, label="Zoom")
        widgetsizer.Add(self.zoom, 1, wx.ALL, 15)
        self.zoom.Bind(wx.EVT_TOGGLEBUTTON, self.zoomButton)
        self.zoom.Enable(False)

        self.home = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Home")
        widgetsizer.Add(self.home, 1, wx.ALL, 15)
        self.home.Bind(wx.EVT_BUTTON, self.homeButton)
        self.home.Enable(False)

        self.pan = wx.ToggleButton(self.widget_panel, id=wx.ID_ANY, label="Pan")
        widgetsizer.Add(self.pan, 1, wx.ALL, 15)
        self.pan.Bind(wx.EVT_TOGGLEBUTTON, self.panButton)
        self.pan.Enable(False)

        self.lock = wx.CheckBox(self.widget_panel, id=wx.ID_ANY, label="Lock View")
        widgetsizer.Add(self.lock, 1, wx.ALL, 15)
        self.lock.Bind(wx.EVT_CHECKBOX, self.lockChecked)
        self.lock.Enable(False)

        self.save = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Save")
        widgetsizer.Add(self.save, 1, wx.ALL, 15)
        self.save.Bind(wx.EVT_BUTTON, self.saveDataSet)
        self.save.Enable(False)

        widgetsizer.AddStretchSpacer(15)
        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL | wx.ALIGN_RIGHT, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)

        # Attach the sizer once all widgets are added, then fit and lay out the panel.
        self.widget_panel.SetSizerAndFit(widgetsizer)
        self.widget_panel.Layout()
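
        # NOTE: illustrative refactor sketch, not part of the original code.
        # The button wiring above is repetitive; a hypothetical helper could
        # create, add, and bind each plain button in one call, e.g.:
        #
        #     def _add_button(self, sizer, label, handler, enabled=True):
        #         btn = wx.Button(self.widget_panel, id=wx.ID_ANY, label=label)
        #         sizer.Add(btn, 1, wx.ALL, 15)
        #         btn.Bind(wx.EVT_BUTTON, handler)
        #         btn.Enable(enabled)
        #         return btn
        #
        #     self.load = self._add_button(widgetsizer, "Load labels", self.browseDir)
        #     self.prev = self._add_button(widgetsizer, "<<Previous", self.prevImage, enabled=False)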

###############################################################################################################################
        # Variable initialization
        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []
        self.file = 0
        self.updatedCoords = []
        self.drs = []
        self.cfg = auxiliaryfunctions.read_config(config)
        self.humanscorer = self.cfg['scorer']
        self.move2corner = self.cfg['move2corner']
        self.center = self.cfg['corner2move2']
        self.colormap = plt.get_cmap(self.cfg['colormap'])
        self.colormap = self.colormap.reversed()
        self.markerSize = self.cfg['dotsize']
        self.alpha = self.cfg['alphavalue']
        self.iterationindex = self.cfg['iteration']
        self.project_path = self.cfg['project_path']
        self.bodyparts = self.cfg['bodyparts']
        self.threshold = 0.1
        self.img_size = (10, 6)  # (imgW, imgH): width, height in inches.
        self.preview = False
        self.view_locked = False
        # Workaround for macOS: xlim/ylim "changed" events seem to fire too
        # often, so verify the limits actually changed before turning zoom off.
        self.prezoom_xlim = []
        self.prezoom_ylim = []
        from deeplabcut.utils import auxfun_multianimal
        (
            self.individual_names,
            self.uniquebodyparts,
            self.multianimalbodyparts,
        ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
        # self.choiceBox,self.visualization_rdb = self.choice_panel.addRadioButtons()
        self.Colorscheme = visualization.get_cmap(len(self.individual_names), self.cfg['colormap'])
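        # Presumably self.Colorscheme maps an individual's index to an RGBA
        # color (matplotlib-style), e.g. self.Colorscheme(0) for the first
        # animal; this note is an assumption based on get_cmap's usual contract.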