Exemple #1
0
def transformer_reID(
    config,
    videos,
    videotype=".mp4",
    shuffle=1,
    trainingsetindex=0,
    track_method="ellipse",
    n_tracks=None,
    n_triplets=1000,
    train_epochs=100,
    train_frac=0.8,
    modelprefix="",
):
    """
    Enables tracking with transformer.

    Substeps include:

    - Mines triplets from tracklets in videos (from another tracker)
    - These triplets are later used to train a transformer with triplet loss
    - The transformer derived appearance similarity is then used as a stitching loss when tracklets are
    stitched during tracking.

    Outputs: The tracklet file is saved in the same folder where the non-transformer tracklet file is stored.

    Parameters
    ----------
    config: string
        Full path of the config.yaml file as a string.

    videos: list
        A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.

    videotype: (optional) str
        extension for the video file

    shuffle : int, optional
        which shuffle to use

    trainingsetindex : int, optional
        which training fraction to use, identified by its index into
        ``cfg["TrainingFraction"]``

    track_method: str, optional
        track method from which tracklets are sampled

    n_tracks: int, optional
        number of tracks to be formed in the videos; forwarded unchanged to
        ``deeplabcut.stitch_tracklets``.
        TODO: handling videos with different number of tracks

    n_triplets: (optional) int
        number of triplets to be mined from the videos

    train_epochs: (optional), int
        number of epochs to train the transformer; also determines the name of
        the checkpoint file that is looked up after training

    train_frac: (optional), fraction
        fraction of triplets used for training/testing of the transformer

    modelprefix: (optional) str
        forwarded unchanged to every pipeline step (scorer name lookup, dataset
        creation, network path lookup, transformer training and stitching)

    Raises
    ------
    FileNotFoundError
        If ``dlc_transreid_{train_epochs}.pth`` is not present in the snapshot
        folder once training has returned.

    Examples
    --------

    Training model for one video based on ellipse-tracker derived tracklets
    >>> deeplabcut.transformer_reID(path_config_file, ['/home/alex/video.mp4'], track_method="ellipse")

    --------

    """
    # Imported at call time, as in the original implementation.
    import deeplabcut
    import os
    from deeplabcut.utils import auxiliaryfunctions

    # calling create_tracking_dataset, train_tracking_transformer, stitch_tracklets

    cfg = auxiliaryfunctions.read_config(config)

    # Only the scorer name is needed; the second returned value is discarded.
    DLCscorer, _ = auxiliaryfunctions.GetScorerName(
        cfg,
        shuffle=shuffle,
        trainFraction=cfg["TrainingFraction"][trainingsetindex],
        modelprefix=modelprefix,
    )

    # Step 1: mine triplets from the tracklets produced by `track_method`.
    deeplabcut.pose_estimation_tensorflow.create_tracking_dataset(
        config,
        videos,
        track_method,
        videotype=videotype,
        shuffle=shuffle,
        trainingsetindex=trainingsetindex,
        modelprefix=modelprefix,
        n_triplets=n_triplets,
    )

    # Only the snapshot folder is used below; the two pose-config paths that
    # are returned alongside it are not needed here.
    _, _, snapshotfolder = deeplabcut.return_train_network_path(
        config,
        shuffle=shuffle,
        modelprefix=modelprefix,
        trainingsetindex=trainingsetindex,
    )

    # Step 2: train the transformer, writing checkpoints to the snapshot folder.
    deeplabcut.pose_tracking_pytorch.train_tracking_transformer(
        config,
        DLCscorer,
        videos,
        videotype=videotype,
        train_frac=train_frac,
        modelprefix=modelprefix,
        train_epochs=train_epochs,
        ckpt_folder=snapshotfolder,
    )

    # The checkpoint of the final epoch is the one handed to the stitcher.
    transformer_checkpoint = os.path.join(
        snapshotfolder, f"dlc_transreid_{train_epochs}.pth"
    )

    # Fail early with a clear message rather than inside the stitching step.
    if not os.path.exists(transformer_checkpoint):
        raise FileNotFoundError(
            f"checkpoint {transformer_checkpoint} not found")

    # Step 3: stitch tracklets using the trained transformer checkpoint.
    deeplabcut.stitch_tracklets(
        config,
        videos,
        videotype=videotype,
        shuffle=shuffle,
        trainingsetindex=trainingsetindex,
        track_method=track_method,
        modelprefix=modelprefix,
        n_tracks=n_tracks,
        transformer_checkpoint=transformer_checkpoint,
    )
Exemple #2
0
"""
trainIndices, testIndices=deeplabcut.mergeandsplit(path_config_file,trainindex=0,uniform=True)
deeplabcut.create_training_dataset(path_config_file,Shuffles=[2],trainIndices=trainIndices,testIndices=testIndices)
deeplabcut.create_training_dataset(path_config_file,Shuffles=[3],trainIndices=trainIndices,testIndices=testIndices)
for shuffle in [2,3]:
	if shuffle==3:
		posefile=os.path.join(cfg['project_path'],'dlc-models/iteration-'+str(cfg['iteration'])+'/'+ cfg['Task'] + cfg['date'] + '-trainset' + str(int(cfg['TrainingFraction'][0] * 100)) + 'shuffle' + str(shuffle),'train/pose_cfg.yaml')

		DLC_config=deeplabcut.auxiliaryfunctions.read_plainconfig(posefile)
		DLC_config['dataset_type']='tensorpack'
		deeplabcut.auxiliaryfunctions.write_plainconfig(posefile,DLC_config)
"""

# Visit shuffle indices 1 through 6 (assuming `np` is NumPy, so that
# `1 + np.arange(6)` yields 1, 2, ..., 6) and adjust each shuffle's pose
# configuration before the rest of the loop body runs.
for shuffle in 1 + np.arange(6):

    # Only the first of the three returned values is kept; it is the file that
    # `edit_config` modifies below (presumably the training pose config path --
    # confirm against `return_train_network_path`).
    posefile, _, _ = deeplabcut.return_train_network_path(path_config_file,
                                                          shuffle=shuffle)

    # Shuffles 1 and 4: full rotation range plus motion blur.
    if shuffle % 3 == 1:  # imgaug
        edits = {"rotation": 180, "motion_blur": True}
        DLC_config = deeplabcut.auxiliaryfunctions.edit_config(posefile, edits)

    # Shuffles 3 and 6: full rotation range plus a noise sigma of 0.01.
    # Shuffles 2 and 5 match neither branch, so no edits are applied for them.
    elif shuffle % 3 == 0:  # Tensorpack:
        edits = {"rotation": 180, "noise_sigma": 0.01}
        DLC_config = deeplabcut.auxiliaryfunctions.edit_config(posefile, edits)

    print("TRAIN NETWORK", shuffle)
    deeplabcut.train_network(
        path_config_file,
        shuffle=shuffle,
        saveiters=10000,
        displayiters=200,
    print("Video created.")

    print("Convert detections to tracklets...")
    deeplabcut.convert_detections2tracklets(
        config_path, [new_video_path], "mp4", track_method=TESTTRACKER
    )
    print("Tracklets created...")

    ### adding it here
    modelprefix = ""
    (
        trainposeconfigfile,
        testposeconfigfile,
        snapshotfolder,
    ) = deeplabcut.return_train_network_path(
        config_path, shuffle=1, modelprefix=modelprefix, trainingsetindex=0
    )

    print("Creating triplet dataset")

    deeplabcut.pose_estimation_tensorflow.create_tracking_dataset(
        config_path, [new_video_path], TESTTRACKER, videotype="mp4",
    )

    train_epochs = 10
    train_frac = 0.8

    print("Training transformer")

    deeplabcut.pose_tracking_pytorch.train_tracking_transformer(
        config_path,