def split_randomly(db_file,
                   dataset,
                   config,
                   args,
                   test_clips=None,
                   balance_bins=True):
    """Split ``dataset`` into train, validation and test ``Dataset`` objects.

    For every label, ``split_label`` chooses which cameras go to train,
    validation and test, so that a clip appears in only one of the datasets.
    An additional test camera, built from ``test_clips`` by
    ``get_test_set_camera``, is placed in the test set first.

    Args:
        db_file: passed straight through to each ``Dataset`` constructor.
        dataset: source dataset; its ``labels`` are iterated and it is handed
            to ``get_test_set_camera`` and ``split_label``.
        config: passed straight through to each ``Dataset`` constructor.
        args: only ``args.date`` is read (forwarded to ``get_test_set_camera``).
        test_clips: clips forwarded to ``get_test_set_camera``. ``None`` (the
            default) means a fresh empty list for this call.
        balance_bins: forwarded to ``add_camera_tracks`` for every split.

    Returns:
        ``(train, validation, test)``. Augmentation is explicitly enabled on
        the training dataset only.
    """
    # Splitting randomly was chosen over other schemes (e.g. by location or by
    # camera) because it is the simplest and gave the same results.
    if test_clips is None:
        # A literal ``[]`` default would be one list shared by every call.
        test_clips = []

    train = Dataset(db_file, "train", config)
    train.enable_augmentation = True
    validation = Dataset(db_file, "validation", config)
    test = Dataset(db_file, "test", config)

    train_cameras = []
    validate_cameras = []
    test_cameras = []

    fixed_test_camera = get_test_set_camera(dataset, test_clips, args.date)
    # Skip a missing camera, the same way the split_label results are handled.
    if fixed_test_camera is not None:
        test_cameras.append(fixed_test_camera)

    for label in dataset.labels:
        # NOTE(review): no tracks are added to `test` until add_camera_tracks
        # runs below, so unless Dataset's constructor pre-loads tracks this
        # count is always 0. It was presumably meant to count this label's
        # tracks in the fixed test camera -- confirm.
        existing_test_count = len(test.tracks_by_label.get(label, []))
        label_train, label_validate, label_test = split_label(
            dataset, label, existing_test_count=existing_test_count)
        if label_train is not None:
            train_cameras.append(label_train)
        if label_validate is not None:
            validate_cameras.append(label_validate)
        if label_test is not None:
            test_cameras.append(label_test)

    add_camera_tracks(dataset.labels, train, train_cameras, balance_bins)
    add_camera_tracks(dataset.labels, validation, validate_cameras,
                      balance_bins)
    add_camera_tracks(dataset.labels, test, test_cameras, balance_bins)
    return train, validation, test
Beispiel #2
0
def split_dataset_by_cameras(db, dataset, config, args, *,
                             validation_percent=0.3):
    """Split ``dataset`` into train and validation ``Dataset`` objects by camera.

    Cameras are assigned whole to one side or the other:

    * the training part of ``split_wallaby_cameras`` goes to training, and its
      validation part (if any) to validation;
    * the camera with id ``"ruru19w44a-[-36.03915 174.51675]"`` (which holds
      all the rabbits) is forced into training;
    * ``diverse_validation`` picks the validation cameras from the rest, and
      every camera it does not pick goes to training.

    Args:
        db: passed straight through to each ``Dataset`` constructor.
        dataset: source dataset; ``cameras_by_id`` and ``labels`` are read.
        config: passed straight through to each ``Dataset`` constructor.
        args: unused here; kept so the signature stays unchanged.
        validation_percent: fraction of all cameras to aim for in validation.
            The count is never below ``MIN_VALIDATE_CAMERAS``.

    Returns:
        ``(train, validation)``. Augmentation is explicitly enabled on the
        training dataset only.
    """
    train = Dataset(db, "train", config)
    train.enable_augmentation = True
    validation = Dataset(db, "validation", config)

    train_cameras = []
    cameras = list(dataset.cameras_by_id.values())
    # Sized from *all* cameras, before the wallaby/rabbit cameras below are
    # taken out for training.
    num_validate_cameras = max(
        MIN_VALIDATE_CAMERAS, round(len(cameras) * validation_percent)
    )

    wallaby, wallaby_validate = split_wallaby_cameras(dataset, cameras)
    if wallaby:
        train_cameras.append(wallaby)

    # This camera has all the rabbits, so it always goes into training.
    rabbits = dataset.cameras_by_id.get("ruru19w44a-[-36.03915 174.51675]")
    if rabbits:
        cameras.remove(rabbits)
        train_cameras.append(rabbits)

    validate_cameras, cameras = diverse_validation(
        cameras, dataset.labels, num_validate_cameras
    )
    if wallaby_validate:
        validate_cameras.append(wallaby_validate)
    # Whatever diverse_validation did not pick goes to training.
    train_cameras.extend(cameras)

    add_camera_tracks(dataset.labels, train, train_cameras)
    add_camera_tracks(dataset.labels, validation, validate_cameras)

    return train, validation